From e802db40be18f3b40b3baf8abfe93d72da6f0f95 Mon Sep 17 00:00:00 2001 From: "Charles A. Daniels" Date: Thu, 23 Feb 2023 19:02:52 -0500 Subject: [PATCH] add support for changing the model with --model/-m --- main.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 75e5cd5..23c4da4 100644 --- a/main.go +++ b/main.go @@ -97,6 +97,37 @@ var Temp float64 var TopP float64 var PresencePenalty float64 var FrequencyPenalty float64 +var Model string + +// checkModel verifies that the selected mode id is one that go-gpt3 knows +// about, producing an error if not. +// +// TODO: in future, this could probably leverage gpt3.Client.ListModels() to +// support the user's fine-tuned models. +func checkModel(m string) error { + knownModels := []string{ + gpt3.GPT3TextDavinci003, + gpt3.GPT3TextDavinci002, + gpt3.GPT3TextCurie001, + gpt3.GPT3TextBabbage001, + gpt3.GPT3TextAda001, + gpt3.GPT3TextDavinci001, + gpt3.GPT3DavinciInstructBeta, + gpt3.GPT3Davinci, + gpt3.GPT3CurieInstructBeta, + gpt3.GPT3Curie, + gpt3.GPT3Ada, + gpt3.GPT3Babbage, + } + + for _, v := range knownModels { + if m == v { + return nil + } + } + + return fmt.Errorf("unknown model '%s', expected one of: %s", m, strings.Join(knownModels, ", ")) +} func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) { if CleanPrompt { @@ -108,8 +139,13 @@ func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question st question += "\n" } + err := checkModel(Model) + if err != nil { + return nil, err + } + req := gpt3.CompletionRequest{ - Model: gpt3.GPT3TextDavinci003, + Model: Model, MaxTokens: MaxTokens, Prompt: question, Echo: Echo, @@ -136,7 +172,13 @@ func GetEditsResponse(client *gpt3.Client, ctx context.Context, input, instructi input = strings.ReplaceAll(input, "\n", " ") input = strings.ReplaceAll(input, " ", " ") } - m := gpt3.GPT3TextDavinci003 + + err := checkModel(Model) + if err != nil { + return nil, err + } + + m := Model req := gpt3.EditsRequest{ Model: &m, Input: input, @@ -222,6 +264,7 @@ type NullWriter int func (NullWriter) Write([]byte) (int, error) { return 0, nil } func main() { + apiKey := os.Getenv("CHATGPT_API_KEY") if apiKey == "" { fmt.Println("CHATGPT_API_KEY environment var is missing\nVisit https://platform.openai.com/account/api-keys to get one\n") @@ -361,6 +404,7 @@ func main() { rootCmd.Flags().Float64VarP(&TopP, "topp", "", 1.0, "set the TopP parameter") rootCmd.Flags().Float64VarP(&PresencePenalty, "pres", "", 0.0, "set the Presence Penalty parameter") rootCmd.Flags().Float64VarP(&FrequencyPenalty, "freq", "", 0.0, "set the Frequency Penalty parameter") + rootCmd.Flags().StringVarP(&Model, "model", "m", gpt3.GPT3TextDavinci003, "select the model to use with -q or -e") // run the command rootCmd.Execute()