diff --git a/README.md b/README.md
index cdcab7f..361f452 100644
--- a/README.md
+++ b/README.md
@@ -45,8 +45,8 @@ Examples:
   chatgpt -p cynic -q "Is the world going to be ok?"
   chatgpt -p teacher convo.txt
 
-  # extra options
-  chatgpt -t 4096 # set max tokens in reponse
+  # model options
+  chatgpt -T 4096 # set max tokens in response
   chatgpt -c      # clean whitespace before sending
 
 Usage:
diff --git a/main.go b/main.go
index ae322e6..de4d1d9 100644
--- a/main.go
+++ b/main.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"os"
 	"runtime/debug"
+	"strconv"
 	"strings"
 
 	gpt3 "github.com/sashabaranov/go-gpt3"
@@ -44,27 +45,57 @@ Examples:
   chatgpt -p cynic -q "Is the world going to be ok?"
   chatgpt -p teacher convo.txt
 
-  # extra options
-  chatgpt -t 4096 # set max tokens in reponse
-  chatgpt -c      # clean whitespace before sending
+  # edit mode
+  chatgpt -e ...
+
+  # code mode
+  chatgpt -c ...
+
+  # model options (https://platform.openai.com/docs/api-reference/completions/create)
+  chatgpt -T 4096  # set max tokens in response  [0,4096]
+  chatgpt -x       # clean whitespace before sending
+  chatgpt --temp   # set the temperature param   [0.0,2.0]
+  chatgpt --topp   # set the TopP param          [0.0,1.0]
+  chatgpt --pres   # set the Presence Penalty    [-2.0,2.0]
+  chatgpt --freq   # set the Frequency Penalty   [-2.0,2.0]
+
+
 `
 
 var interactiveHelp = `starting interactive session...
-  'quit' to exit, 'save <name>' to preserve
+  'quit' to exit
+  'save <name>' to preserve
+  'tokens' to change the MaxTokens param
+  'count' to change number of responses
+  'temp' set the temperature param    [0.0,2.0]
+  'topp' set the TopP param           [0.0,1.0]
+  'pres' set the Presence Penalty     [-2.0,2.0]
+  'freq' set the Frequency Penalty    [-2.0,2.0]
 `
 
 //go:embed pretexts/*
 var predefined embed.FS
 
 var Version bool
+
+// prompt vars
 var Question string
 var Pretext string
-var MaxTokens int
 var PromptMode bool
+var EditMode bool
+var CodeMode bool
 var CleanPrompt bool
 var WriteBack bool
 var PromptText string
 
-func GetResponse(client *gpt3.Client, ctx context.Context, question string) (string, error) {
+// chatgpt vars
+var MaxTokens int
+var Count int
+var Temp float64
+var TopP float64
+var PresencePenalty float64
+var FrequencyPenalty float64
+
+func GetCompletionResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) {
 	if CleanPrompt {
 		question = strings.ReplaceAll(question, "\n", " ")
 		question = strings.ReplaceAll(question, "  ", " ")
@@ -73,13 +104,75 @@ func GetResponse(client *gpt3.Client, ctx context.Context, question string) (str
 		Model:     gpt3.GPT3TextDavinci003,
 		MaxTokens: MaxTokens,
 		Prompt:    question,
+		N:                Count,
+		Temperature:      float32(Temp),
+		TopP:             float32(TopP),
+		PresencePenalty:  float32(PresencePenalty),
+		FrequencyPenalty: float32(FrequencyPenalty),
 	}
 	resp, err := client.CreateCompletion(ctx, req)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 
-	return resp.Choices[0].Text, nil
+	var r []string
+	for _, c := range resp.Choices {
+		r = append(r, c.Text)
+	}
+	return r, nil
+}
+
+func GetEditsResponse(client *gpt3.Client, ctx context.Context, input, instruction string) ([]string, error) {
+	if CleanPrompt {
+		input = strings.ReplaceAll(input, "\n", " ")
+		input = strings.ReplaceAll(input, "  ", " ")
+	}
+	m := gpt3.GPT3TextDavinci003
+	req := gpt3.EditsRequest{
+		Model:       &m,
+		Input:       input,
+		Instruction: instruction,
+		N:           Count,
+		Temperature: float32(Temp),
+		TopP:        float32(TopP),
+	}
+	resp, err := client.Edits(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	var r []string
+	for _, c := range resp.Choices {
+		r = append(r, c.Text)
+	}
+	return r, nil
+}
+
+func GetCodeResponse(client *gpt3.Client, ctx context.Context, question string) ([]string, error) {
+	if CleanPrompt {
+		question = strings.ReplaceAll(question, "\n", " ")
+		question = strings.ReplaceAll(question, "  ", " ")
+	}
+	req := gpt3.CompletionRequest{
+		Model:            gpt3.CodexCodeDavinci002,
+		MaxTokens:        MaxTokens,
+		Prompt:           question,
+		N:                Count,
+		Temperature:      float32(Temp),
+		TopP:             float32(TopP),
+		PresencePenalty:  float32(PresencePenalty),
+		FrequencyPenalty: float32(FrequencyPenalty),
+	}
+	resp, err := client.CreateCompletion(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	var r []string
+	for _, c := range resp.Choices {
+		r = append(r, c.Text)
+	}
+	return r, nil
 }
 
 func printVersion() {
@@ -127,7 +220,6 @@ func main() {
 		Short: "Chat with ChatGPT in console.",
 		Long:  LongHelp,
 		Run: func(cmd *cobra.Command, args []string) {
-			fmt.Println(Version)
 			if Version {
 				printVersion()
 				os.Exit(0)
@@ -212,7 +304,7 @@ func main() {
 	}
 
 	// if there is a question, it comes last in the prompt
-	if Question != "" {
+	if Question != "" && !EditMode {
 		PromptText += "\n" + Question
 	}
 
@@ -235,14 +327,27 @@ func main() {
 	}
 
 	// setup flags
+	rootCmd.Flags().BoolVarP(&Version, "version", "", false, "print version information")
+
+	// prompt related
 	rootCmd.Flags().StringVarP(&Question, "question", "q", "", "ask a single question and print the response back")
 	rootCmd.Flags().StringVarP(&Pretext, "pretext", "p", "", "pretext to add to ChatGPT input, use 'list' or 'view:<name>' to inspect predefined, '<name>' to use a pretext, or otherwise supply any custom text")
 	rootCmd.Flags().BoolVarP(&PromptMode, "interactive", "i", false, "start an interactive session with ChatGPT")
-	rootCmd.Flags().BoolVarP(&CleanPrompt, "clean", "c", false, "remove excess whitespace from prompt before sending")
+	rootCmd.Flags().BoolVarP(&EditMode, "edit", "e", false, "request an edit with ChatGPT")
+	rootCmd.Flags().BoolVarP(&CodeMode, "code", "c", false, "request code completion with ChatGPT")
+	rootCmd.Flags().BoolVarP(&CleanPrompt, "clean", "x", false, "remove excess whitespace from prompt before sending")
 	rootCmd.Flags().BoolVarP(&WriteBack, "write", "w", false, "write response to end of context file")
-	rootCmd.Flags().IntVarP(&MaxTokens, "tokens", "t", 420, "set the MaxTokens to generate per response")
-	rootCmd.Flags().BoolVarP(&Version, "version", "", false, "print version information")
 
+	// params related
+	rootCmd.Flags().IntVarP(&MaxTokens, "tokens", "T", 1024, "set the MaxTokens to generate per response")
+	rootCmd.Flags().IntVarP(&Count, "count", "C", 1, "set the number of response options to create")
+	rootCmd.Flags().Float64VarP(&Temp, "temp", "", 1.0, "set the temperature parameter")
+	rootCmd.Flags().Float64VarP(&TopP, "topp", "", 1.0, "set the TopP parameter")
+	rootCmd.Flags().Float64VarP(&PresencePenalty, "pres", "", 0.0, "set the Presence Penalty parameter")
+	rootCmd.Flags().Float64VarP(&FrequencyPenalty, "freq", "", 0.0, "set the Frequency Penalty parameter")
+
+
+	// run the command
 	rootCmd.Execute()
 }
 
@@ -259,36 +364,163 @@ func RunPrompt(client *gpt3.Client) error {
 		}
 		question := scanner.Text()
+		parts := strings.Fields(question)
+
+		// look for commands
+		switch parts[0] {
+		case "quit", "q", "exit":
+			quit = true
+			continue
 
-		if strings.HasPrefix(question, "save") {
-			parts := strings.Fields(question)
+		case "save":
 			name := parts[1]
-			fmt.Printf("saving conversation to %s\n", name)
+			fmt.Printf("saving session to %s\n", name)
 			err := os.WriteFile(name, []byte(PromptText), 0644)
 			if err != nil {
-				return err
+				fmt.Println(err)
 			}
+			continue
 
+		case "tokens":
+			if len(parts) == 1 {
+				fmt.Println("tokens is set to", MaxTokens)
+				continue
+			}
+			c, err := strconv.Atoi(parts[1])
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			MaxTokens = c
+			fmt.Println("tokens is now", MaxTokens)
 			continue
-		}
 
-		switch question {
-		case "quit", "q", "exit":
-			quit = true
+		case "count":
+			if len(parts) == 1 {
+				fmt.Println("count is set to", Count)
+				continue
+			}
+			c, err := strconv.Atoi(parts[1])
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			Count = c
+			fmt.Println("count is now", Count)
+			continue
+
+		case "temp":
+			if len(parts) == 1 {
+				fmt.Println("temp is set to", Temp)
+				continue
+			}
+			f, err := strconv.ParseFloat(parts[1], 64)
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+			Temp = f
+			fmt.Println("temp is now", Temp)
+
+		case "topp":
+			if len(parts) == 1 {
+				fmt.Println("topp is set to", TopP)
+				continue
+			}
+			f, err := strconv.ParseFloat(parts[1], 64)
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+			TopP = f
+			fmt.Println("topp is now", TopP)
+
+		case "pres":
+			if len(parts) == 1 {
+				fmt.Println("pres is set to", PresencePenalty)
+				continue
+			}
+			f, err := strconv.ParseFloat(parts[1], 64)
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+			PresencePenalty = f
+			fmt.Println("pres is now", PresencePenalty)
+
+		case "freq":
+			if len(parts) == 1 {
+				fmt.Println("freq is set to", FrequencyPenalty)
+				continue
+			}
+			f, err := strconv.ParseFloat(parts[1], 64)
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+			FrequencyPenalty = f
+			fmt.Println("freq is now", FrequencyPenalty)
+
 		default:
 			// add the question to the existing prompt text, to keep context
 			PromptText += "\n> " + question
-			r, err := GetResponse(client, ctx, PromptText)
+			var R []string
+			var err error
+
+			if CodeMode {
+				R, err = GetCodeResponse(client, ctx, PromptText)
+			} else if EditMode {
+				R, err = GetEditsResponse(client, ctx, PromptText, Question)
+			} else {
+				R, err = GetCompletionResponse(client, ctx, PromptText)
+			}
 			if err != nil {
 				return err
 			}
+			final := ""
+
+			if len(R) == 1 {
+				final = R[0]
+			} else {
+				for i, r := range R {
+					final += fmt.Sprintf("[%d]: %s\n\n", i, r)
+				}
+				fmt.Println(final)
+				ok := false
+				pos := 0
+
+				for !ok {
+					fmt.Print("> ")
+
+					if !scanner.Scan() {
+						break
+					}
+
+					ans := scanner.Text()
+					pos, err = strconv.Atoi(ans)
+					if err != nil {
+						fmt.Println(err)
+						continue
+					}
+					if pos < 0 || pos >= Count {
+						fmt.Println("choice must be between 0 and", Count-1)
+						continue
+					}
+					ok = true
+				}
+
+				final = R[pos]
+			}
+
 			// we add response to the prompt, this is how ChatGPT sessions keep context
-			PromptText += "\n" + strings.TrimSpace(r)
+			PromptText += "\n" + strings.TrimSpace(final)
 			// print the latest portion of the conversation
-			fmt.Println(r + "\n")
+			fmt.Println(final + "\n")
 		}
 	}
 
@@ -298,15 +530,33 @@ func RunPrompt(client *gpt3.Client) error {
 func RunOnce(client *gpt3.Client, filename string) error {
 	ctx := context.Background()
 
-	r, err := GetResponse(client, ctx, PromptText)
+	var R []string
+	var err error
+
+	if CodeMode {
+		R, err = GetCodeResponse(client, ctx, PromptText)
+	} else if EditMode {
+		R, err = GetEditsResponse(client, ctx, PromptText, Question)
+	} else {
+		R, err = GetCompletionResponse(client, ctx, PromptText)
+	}
 	if err != nil {
 		return err
 	}
 
+	final := ""
+	if len(R) == 1 {
+		final = R[0]
+	} else {
+		for i, r := range R {
+			final += fmt.Sprintf("[%d]: %s\n\n", i, r)
+		}
+	}
+
 	if filename == "" || !WriteBack {
-		fmt.Println(r)
+		fmt.Println(final)
 	} else {
-		err = AppendToFile(filename, r)
+		err = AppendToFile(filename, final)
 		if err != nil {
 			return err
 		}