add support for changing the model with --model/-m

pull/1/head
Charles A. Daniels 1 year ago
parent 210d15ae70
commit e802db40be

@ -97,6 +97,37 @@ var Temp float64
var TopP float64
var PresencePenalty float64
var FrequencyPenalty float64
var Model string
// checkModel verifies that the selected mode id is one that go-gpt3 knows
// about, producing an error if not.
//
// TODO: in future, this could probably leverage gpt3.Client.ListModels() to
// support the user's fine-tuned models.
func checkModel(m string) error {
knownModels := []string{
gpt3.GPT3TextDavinci003,
gpt3.GPT3TextDavinci002,
gpt3.GPT3TextCurie001,
gpt3.GPT3TextBabbage001,
gpt3.GPT3TextAda001,
gpt3.GPT3TextDavinci001,
gpt3.GPT3DavinciInstructBeta,
gpt3.GPT3Davinci,
gpt3.GPT3CurieInstructBeta,
gpt3.GPT3Curie,
gpt3.GPT3Ada,
gpt3.GPT3Babbage,
}
for _, v := range knownModels {
if m == v {
return nil
}
}
return fmt.Errorf("unknown model '%s', expected one of: %s", m, strings.Join(knownModels, ", "))
}
func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) {
if CleanPrompt {
@ -108,8 +139,13 @@ func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question st
question += "\n"
}
err := checkModel(Model)
if err != nil {
return nil, err
}
req := gpt3.CompletionRequest{
Model: gpt3.GPT3TextDavinci003,
Model: Model,
MaxTokens: MaxTokens,
Prompt: question,
Echo: Echo,
@ -136,7 +172,13 @@ func GetEditsResponse(client *gpt3.Client, ctx context.Context, input, instructi
input = strings.ReplaceAll(input, "\n", " ")
input = strings.ReplaceAll(input, " ", " ")
}
m := gpt3.GPT3TextDavinci003
err := checkModel(Model)
if err != nil {
return nil, err
}
m := Model
req := gpt3.EditsRequest{
Model: &m,
Input: input,
@ -222,6 +264,7 @@ type NullWriter int
func (NullWriter) Write([]byte) (int, error) { return 0, nil }
func main() {
apiKey := os.Getenv("CHATGPT_API_KEY")
if apiKey == "" {
fmt.Println("CHATGPT_API_KEY environment var is missing\nVisit https://platform.openai.com/account/api-keys to get one\n")
@ -361,6 +404,7 @@ func main() {
rootCmd.Flags().Float64VarP(&TopP, "topp", "", 1.0, "set the TopP parameter")
rootCmd.Flags().Float64VarP(&PresencePenalty, "pres", "", 0.0, "set the Presence Penalty parameter")
rootCmd.Flags().Float64VarP(&FrequencyPenalty, "freq", "", 0.0, "set the Frequency Penalty parameter")
rootCmd.Flags().StringVarP(&Model, "model", "m", gpt3.GPT3TextDavinci003, "select the model to use with -q or -e")
// run the command
rootCmd.Execute()

Loading…
Cancel
Save