diff --git a/go.mod b/go.mod index 07c0f4e..8edde2c 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/verdverm/chatgpt go 1.18 require ( - github.com/sashabaranov/go-gpt3 v1.0.1 + github.com/sashabaranov/go-openai v1.5.0 github.com/spf13/cobra v1.6.1 ) diff --git a/go.sum b/go.sum index 78ec907..1774fa5 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-gpt3 v1.0.1 h1:KHwY4uroFlX1qI1Hui7d31ZI6uzbNGL9zAkh1FkfhuM= -github.com/sashabaranov/go-gpt3 v1.0.1/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ= +github.com/sashabaranov/go-openai v1.5.0 h1:4Gr/7g/KtVzW0ddn7TC2aUlyzvhZBIM+qRZ6Ae2kMa0= +github.com/sashabaranov/go-openai v1.5.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/main.go b/main.go index 75e5cd5..929be47 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,7 @@ import ( "strconv" "strings" - gpt3 "github.com/sashabaranov/go-gpt3" + gpt3 "github.com/sashabaranov/go-openai" "github.com/spf13/cobra" ) @@ -52,14 +52,18 @@ Examples: chatgpt -c ... # model options (https://platform.openai.com/docs/api-reference/completions/create) - chatgpt -T 4096 # set max tokens in reponse [0,4096] - chatgpt -C # clean whitespace before sending - chatgpt -E # echo back the prompt, useful for vim coding - chatgpt --temp # set the temperature param [0.0,2.0] - chatgpt --topp # set the TopP param [0.0,1.0] - chatgpt --pres # set the Presence Penalty [-2.0,2.0] - chatgpt --freq # set the Frequency Penalty [-2.0,2.0] - + chatgpt -T 4096 # set max tokens in reponse [0,4096] + chatgpt -C # clean whitespace before sending + chatgpt -E # echo back the prompt, useful for vim coding + chatgpt --temp # set the temperature param [0.0,2.0] + chatgpt --topp # set the TopP param [0.0,1.0] + chatgpt --pres # set the Presence Penalty [-2.0,2.0] + chatgpt --freq # set the Frequency Penalty [-2.0,2.0] + + # change model selection, available models are listed here: + # https://pkg.go.dev/github.com/sashabaranov/go-openai#Client.ListModels + chatgpt -m text-davinci-003 # set the model to text-davinci-003 (the default) + chatgpt -m text-ada-001 # set the model to text-ada-001 ` @@ -72,6 +76,7 @@ var interactiveHelp = `starting interactive session... 'topp' set the TopP param [0.0,1.0] 'pres' set the Presence Penalty [-2.0,2.0] 'freq' set the Frequency Penalty [-2.0,2.0] + 'model' to change the selected model ` //go:embed pretexts/* @@ -97,6 +102,7 @@ var Temp float64 var TopP float64 var PresencePenalty float64 var FrequencyPenalty float64 +var Model string func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) { if CleanPrompt { @@ -109,7 +115,7 @@ func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question st } req := gpt3.CompletionRequest{ - Model: gpt3.GPT3TextDavinci003, + Model: Model, MaxTokens: MaxTokens, Prompt: question, Echo: Echo, @@ -136,7 +142,8 @@ func GetEditsResponse(client *gpt3.Client, ctx context.Context, input, instructi input = strings.ReplaceAll(input, "\n", " ") input = strings.ReplaceAll(input, " ", " ") } - m := gpt3.GPT3TextDavinci003 + + m := Model req := gpt3.EditsRequest{ Model: &m, Input: input, @@ -222,6 +229,7 @@ type NullWriter int func (NullWriter) Write([]byte) (int, error) { return 0, nil } func main() { + apiKey := os.Getenv("CHATGPT_API_KEY") if apiKey == "" { fmt.Println("CHATGPT_API_KEY environment var is missing\nVisit https://platform.openai.com/account/api-keys to get one\n") @@ -361,6 +369,7 @@ func main() { rootCmd.Flags().Float64VarP(&TopP, "topp", "", 1.0, "set the TopP parameter") rootCmd.Flags().Float64VarP(&PresencePenalty, "pres", "", 0.0, "set the Presence Penalty parameter") rootCmd.Flags().Float64VarP(&FrequencyPenalty, "freq", "", 0.0, "set the Frequency Penalty parameter") + rootCmd.Flags().StringVarP(&Model, "model", "m", gpt3.GPT3TextDavinci003, "select the model to use with -q or -e") // run the command rootCmd.Execute() @@ -397,6 +406,16 @@ func RunPrompt(client *gpt3.Client) error { } continue + case "model": + if len(parts) == 1 { + fmt.Println("model is set to", Model) + continue + } + + Model = parts[1] + fmt.Println("model is now", Model) + continue + case "tokens": if len(parts) == 1 { fmt.Println("tokens is set to", MaxTokens)