diff --git a/chat.go b/chat.go index 6304b6d..a98aa9a 100644 --- a/chat.go +++ b/chat.go @@ -12,11 +12,53 @@ import ( "github.com/sashabaranov/go-openai" ) -func RunPrompt(client *openai.Client) error { +var interactiveHelp = `starting interactive session... + 'quit' to exit + 'save ' to preserve + 'clear' to erase the context + 'context' to see the current context + 'prompt' to set the context to a prompt + 'tokens' to change the MaxToken param + 'count' to change number of responses + 'temp' set the temperature param [0.0,2.0] + 'topp' set the TopP param [0.0,1.0] + 'pres' set the Presence Penalty [-2.0,2.0] + 'freq' set the Frequency Penalty [-2.0,2.0] + 'model' to change the selected model +` + +func RunInteractive(client *openai.Client) error { ctx := context.Background() scanner := bufio.NewScanner(os.Stdin) quit := false + // initial req setup + var req openai.ChatCompletionRequest + + // override default model in interactive | chat mode + if !(strings.HasPrefix(Model, "gpt-3") || strings.HasPrefix(Model, "gpt-4")) { + Model = "gpt-3.5-turbo-0301" + fmt.Println("using chat compatible model:", Model, "\n") + } + fmt.Println(interactiveHelp) + fmt.Println(PromptText + "\n") + + req.Model = Model + req.N = Count + req.MaxTokens = MaxTokens + req.Temperature = float32(Temp) + req.TopP = float32(TopP) + req.PresencePenalty = float32(PresencePenalty) + req.FrequencyPenalty = float32(FrequencyPenalty) + + if PromptText != "" { + req.Messages = append(req.Messages, openai.ChatCompletionMessage{ + Role: "system", + Content: PromptText, + }) + } + + // interactive loop for !quit { fmt.Print("> ") @@ -34,7 +76,7 @@ func RunPrompt(client *openai.Client) error { continue case "clear": - PromptText = "" + req.Messages = make([]openai.ChatCompletionMessage,0) case "context": fmt.Println("\n===== Current Context =====") @@ -55,6 +97,18 @@ func RunPrompt(client *openai.Client) error { // prime prompt with custom pretext fmt.Printf("setting prompt to:\n%s", p) PromptText = p + if PromptText != "" { + msg := openai.ChatCompletionMessage{ + Role: "system", + Content: PromptText, + } + // new first message or replace + if len(req.Messages) == 0 { + req.Messages = append(req.Messages, msg) + } else { + req.Messages[0] = msg + } + } case "save": name := parts[1] @@ -68,17 +122,17 @@ func RunPrompt(client *openai.Client) error { case "model": if len(parts) == 1 { - fmt.Println("model is set to", Model) + fmt.Println("model is set to", req.Model) continue } - Model = parts[1] - fmt.Println("model is now", Model) + req.Model = parts[1] + fmt.Println("model is now", req.Model) continue case "tokens": if len(parts) == 1 { - fmt.Println("tokens is set to", MaxTokens) + fmt.Println("tokens is set to", req.MaxTokens) continue } c, err := strconv.Atoi(parts[1]) @@ -87,13 +141,13 @@ func RunPrompt(client *openai.Client) error { continue } - MaxTokens = c - fmt.Println("tokens is now", MaxTokens) + req.MaxTokens = c + fmt.Println("tokens is now", req.MaxTokens) continue case "count": if len(parts) == 1 { - fmt.Println("count is set to", Count) + fmt.Println("count is set to", req.N) continue } c, err := strconv.Atoi(parts[1]) @@ -102,87 +156,85 @@ func RunPrompt(client *openai.Client) error { continue } - Count = c - fmt.Println("count is now", Count) + req.N = c + fmt.Println("count is now", req.N) continue case "temp": if len(parts) == 1 { - fmt.Println("temp is set to", Temp) + fmt.Println("temp is set to", req.Temperature) continue } - f, err := strconv.ParseFloat(parts[1], 64) + f, err := strconv.ParseFloat(parts[1], 32) if err != nil { fmt.Println(err) continue } - Temp = f - fmt.Println("temp is now", Temp) + req.Temperature = float32(f) + fmt.Println("temp is now", req.Temperature) case "topp": if len(parts) == 1 { - fmt.Println("topp is set to", TopP) + fmt.Println("topp is set to", req.TopP) continue } - f, err := strconv.ParseFloat(parts[1], 64) + f, err := strconv.ParseFloat(parts[1], 32) if err != nil { fmt.Println(err) continue } - TopP = f - fmt.Println("topp is now", TopP) + req.TopP = float32(f) + fmt.Println("topp is now", req.TopP) case "pres": if len(parts) == 1 { - fmt.Println("pres is set to", PresencePenalty) + fmt.Println("pres is set to", req.PresencePenalty) continue } - f, err := strconv.ParseFloat(parts[1], 64) + f, err := strconv.ParseFloat(parts[1], 32) if err != nil { fmt.Println(err) continue } - PresencePenalty = f - fmt.Println("pres is now", PresencePenalty) + req.PresencePenalty = float32(f) + fmt.Println("pres is now", req.PresencePenalty) case "freq": if len(parts) == 1 { - fmt.Println("freq is set to", FrequencyPenalty) + fmt.Println("freq is set to", req.FrequencyPenalty) continue } - f, err := strconv.ParseFloat(parts[1], 64) + f, err := strconv.ParseFloat(parts[1], 32) if err != nil { fmt.Println(err) continue } - FrequencyPenalty = f - fmt.Println("freq is now", FrequencyPenalty) + req.FrequencyPenalty = float32(f) + fmt.Println("freq is now", req.FrequencyPenalty) default: - // add the question to the existing prompt text, to keep context - PromptText += "\n> " + question - var R []string var err error - // TODO, chat mode? - if CodeMode { - // R, err = GetCodeResponse(client, ctx, PromptText) - } else if EditMode { - R, err = GetEditsResponse(client, ctx, PromptText, Question) - } else { - R, err = GetCompletionResponse(client, ctx, PromptText) + // add the question to the existing messages + msg := openai.ChatCompletionMessage{ + Role: "user", + Content: question, } + req.Messages = append(req.Messages, msg) + + resp, err := client.CreateChatCompletion(ctx, req) if err != nil { return err } + R := resp.Choices final := "" if len(R) == 1 { - final = R[0] + final = R[0].Message.Content } else { for i, r := range R { - final += fmt.Sprintf("[%d]: %s\n\n", i, r) + final += fmt.Sprintf("[%d]: %s\n\n", i, r.Message.Content) } fmt.Println(final) ok := false @@ -208,11 +260,15 @@ func RunPrompt(client *openai.Client) error { ok = true } - final = R[pos] + final = R[pos].Message.Content } // we add response to the prompt, this is how ChatGPT sessions keep context - PromptText += "\n" + strings.TrimSpace(final) + msg = openai.ChatCompletionMessage{ + Role: "assistant", + Content: final, + } + req.Messages = append(req.Messages, msg) // print the latest portion of the conversation fmt.Println(final + "\n") } diff --git a/get.go b/get.go index 9e1cf09..c69d663 100644 --- a/get.go +++ b/get.go @@ -7,41 +7,6 @@ import ( "github.com/sashabaranov/go-openai" ) -/* -func GetChatCompletionResponse(client *openai.Client, ctx context.Context, question string) ([]string, error) { - if CleanPrompt { - question = strings.ReplaceAll(question, "\n", " ") - question = strings.ReplaceAll(question, " ", " ") - } - // insert newline at end to prevent completion of question - if !strings.HasSuffix(question, "\n") { - question += "\n" - } - - req := openai.ChatCompletionRequest{ - Model: Model, - MaxTokens: MaxTokens, - Prompt: question, - Echo: Echo, - N: Count, - Temperature: float31(Temp), - TopP: float31(TopP), - PresencePenalty: float31(PresencePenalty), - FrequencyPenalty: float31(FrequencyPenalty), - } - resp, err := client.CreateChatCompletion(ctx, req) - if err != nil { - return nil, err - } - - var r []string - for _, c := range resp.Choices { - r = append(r, c.Text) - } - return r, nil -} -*/ - func GetCompletionResponse(client *openai.Client, ctx context.Context, question string) ([]string, error) { if CleanPrompt { question = strings.ReplaceAll(question, "\n", " ") diff --git a/main.go b/main.go index 0438b64..eeeafa0 100644 --- a/main.go +++ b/main.go @@ -66,21 +66,6 @@ Examples: ` -var interactiveHelp = `starting interactive session... - 'quit' to exit - 'save ' to preserve - 'clear' to erase the context - 'context' to see the current context - 'prompt' to set the context to a prompt - 'tokens' to change the MaxToken param - 'count' to change number of responses - 'temp' set the temperature param [0.0,2.0] - 'topp' set the TopP param [0.0,1.0] - 'pres' set the Presence Penalty [-2.0,2.0] - 'freq' set the Frequency Penalty [-2.0,2.0] - 'model' to change the selected model -` - var Version bool // prompt vars @@ -212,9 +197,7 @@ func main() { // interactive or file mode if PromptMode { - fmt.Println(interactiveHelp) - fmt.Println(PromptText) - err = RunPrompt(client) + err = RunInteractive(client) } else { // empty filename (no args) prints to stdout err = RunOnce(client, filename)