|
|
|
@ -97,6 +97,37 @@ var Temp float64
|
|
|
|
|
var TopP float64
|
|
|
|
|
var PresencePenalty float64
|
|
|
|
|
var FrequencyPenalty float64
|
|
|
|
|
var Model string
|
|
|
|
|
|
|
|
|
|
// checkModel verifies that the selected mode id is one that go-gpt3 knows
|
|
|
|
|
// about, producing an error if not.
|
|
|
|
|
//
|
|
|
|
|
// TODO: in future, this could probably leverage gpt3.Client.ListModels() to
|
|
|
|
|
// support the user's fine-tuned models.
|
|
|
|
|
func checkModel(m string) error {
|
|
|
|
|
knownModels := []string{
|
|
|
|
|
gpt3.GPT3TextDavinci003,
|
|
|
|
|
gpt3.GPT3TextDavinci002,
|
|
|
|
|
gpt3.GPT3TextCurie001,
|
|
|
|
|
gpt3.GPT3TextBabbage001,
|
|
|
|
|
gpt3.GPT3TextAda001,
|
|
|
|
|
gpt3.GPT3TextDavinci001,
|
|
|
|
|
gpt3.GPT3DavinciInstructBeta,
|
|
|
|
|
gpt3.GPT3Davinci,
|
|
|
|
|
gpt3.GPT3CurieInstructBeta,
|
|
|
|
|
gpt3.GPT3Curie,
|
|
|
|
|
gpt3.GPT3Ada,
|
|
|
|
|
gpt3.GPT3Babbage,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, v := range knownModels {
|
|
|
|
|
if m == v {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fmt.Errorf("unknown model '%s', expected one of: %s", m, strings.Join(knownModels, ", "))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) {
|
|
|
|
|
if CleanPrompt {
|
|
|
|
@ -108,8 +139,13 @@ func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question st
|
|
|
|
|
question += "\n"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err := checkModel(Model)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req := gpt3.CompletionRequest{
|
|
|
|
|
Model: gpt3.GPT3TextDavinci003,
|
|
|
|
|
Model: Model,
|
|
|
|
|
MaxTokens: MaxTokens,
|
|
|
|
|
Prompt: question,
|
|
|
|
|
Echo: Echo,
|
|
|
|
@ -136,7 +172,13 @@ func GetEditsResponse(client *gpt3.Client, ctx context.Context, input, instructi
|
|
|
|
|
input = strings.ReplaceAll(input, "\n", " ")
|
|
|
|
|
input = strings.ReplaceAll(input, " ", " ")
|
|
|
|
|
}
|
|
|
|
|
m := gpt3.GPT3TextDavinci003
|
|
|
|
|
|
|
|
|
|
err := checkModel(Model)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
m := Model
|
|
|
|
|
req := gpt3.EditsRequest{
|
|
|
|
|
Model: &m,
|
|
|
|
|
Input: input,
|
|
|
|
@ -222,6 +264,7 @@ type NullWriter int
|
|
|
|
|
func (NullWriter) Write([]byte) (int, error) { return 0, nil }
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
|
|
|
|
|
apiKey := os.Getenv("CHATGPT_API_KEY")
|
|
|
|
|
if apiKey == "" {
|
|
|
|
|
fmt.Println("CHATGPT_API_KEY environment var is missing\nVisit https://platform.openai.com/account/api-keys to get one\n")
|
|
|
|
@ -361,6 +404,7 @@ func main() {
|
|
|
|
|
rootCmd.Flags().Float64VarP(&TopP, "topp", "", 1.0, "set the TopP parameter")
|
|
|
|
|
rootCmd.Flags().Float64VarP(&PresencePenalty, "pres", "", 0.0, "set the Presence Penalty parameter")
|
|
|
|
|
rootCmd.Flags().Float64VarP(&FrequencyPenalty, "freq", "", 0.0, "set the Frequency Penalty parameter")
|
|
|
|
|
rootCmd.Flags().StringVarP(&Model, "model", "m", gpt3.GPT3TextDavinci003, "select the model to use with -q or -e")
|
|
|
|
|
|
|
|
|
|
// run the command
|
|
|
|
|
rootCmd.Execute()
|
|
|
|
|