change interactive mode to use chat based models

pull/5/head v0.4.0
Tony Worm 1 year ago
parent 46176a0d3d
commit c50c3a3cc7

@@ -12,11 +12,53 @@ import (
 	"github.com/sashabaranov/go-openai"
 )
 
-func RunPrompt(client *openai.Client) error {
+var interactiveHelp = `starting interactive session...
+'quit' to exit
+'save <filename>' to preserve
+'clear' to erase the context
+'context' to see the current context
+'prompt' to set the context to a prompt
+'tokens' to change the MaxToken param
+'count' to change number of responses
+'temp' set the temperature param [0.0,2.0]
+'topp' set the TopP param [0.0,1.0]
+'pres' set the Presence Penalty [-2.0,2.0]
+'freq' set the Frequency Penalty [-2.0,2.0]
+'model' to change the selected model
+`
+
+func RunInteractive(client *openai.Client) error {
 	ctx := context.Background()
 	scanner := bufio.NewScanner(os.Stdin)
 	quit := false
+
+	// initial req setup
+	var req openai.ChatCompletionRequest
+
+	// override default model in interactive | chat mode
+	if !(strings.HasPrefix(Model, "gpt-3") || strings.HasPrefix(Model, "gpt-4")) {
+		Model = "gpt-3.5-turbo-0301"
+		fmt.Println("using chat compatible model:", Model, "\n")
+	}
+
+	fmt.Println(interactiveHelp)
+	fmt.Println(PromptText + "\n")
+
+	req.Model = Model
+	req.N = Count
+	req.MaxTokens = MaxTokens
+	req.Temperature = float32(Temp)
+	req.TopP = float32(TopP)
+	req.PresencePenalty = float32(PresencePenalty)
+	req.FrequencyPenalty = float32(FrequencyPenalty)
+
+	if PromptText != "" {
+		req.Messages = append(req.Messages, openai.ChatCompletionMessage{
+			Role:    "system",
+			Content: PromptText,
+		})
+	}
+
+	// interactive loop
 	for !quit {
 		fmt.Print("> ")
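For reference, the request that RunInteractive assembles above is an ordinary go-openai chat call. A minimal standalone sketch of the same round trip (the model name and env var are illustrative, not taken from this commit):

```go
package main

import (
	"context"
	"fmt"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// token source is an assumption for this sketch
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	req := openai.ChatCompletionRequest{
		Model: "gpt-3.5-turbo", // any chat-capable model works
		Messages: []openai.ChatCompletionMessage{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: "Say hello."},
		},
	}
	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	// chat responses surface as Choices[i].Message.Content,
	// not Choices[i].Text as with the older completion endpoint
	fmt.Println(resp.Choices[0].Message.Content)
}
```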
@@ -34,7 +76,7 @@ func RunPrompt(client *openai.Client) error {
 			continue
 		case "clear":
-			PromptText = ""
+			req.Messages = make([]openai.ChatCompletionMessage, 0)
 		case "context":
 			fmt.Println("\n===== Current Context =====")
@@ -55,6 +97,18 @@ func RunPrompt(client *openai.Client) error {
 			// prime prompt with custom pretext
 			fmt.Printf("setting prompt to:\n%s", p)
 			PromptText = p
+			if PromptText != "" {
+				msg := openai.ChatCompletionMessage{
+					Role:    "system",
+					Content: PromptText,
+				}
+				// new first message or replace
+				if len(req.Messages) == 0 {
+					req.Messages = append(req.Messages, msg)
+				} else {
+					req.Messages[0] = msg
+				}
+			}
 		case "save":
 			name := parts[1]
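The `prompt` handler above pins the system message to slot 0 of req.Messages: append it the first time, overwrite it afterwards. The same logic as a free-standing helper (hypothetical, shown only to make the invariant explicit):

```go
// setSystemPrompt mirrors the 'prompt' case: the system message,
// when present, always occupies index 0 of the message slice.
func setSystemPrompt(req *openai.ChatCompletionRequest, prompt string) {
	if prompt == "" {
		return
	}
	msg := openai.ChatCompletionMessage{Role: "system", Content: prompt}
	if len(req.Messages) == 0 {
		req.Messages = append(req.Messages, msg)
	} else {
		req.Messages[0] = msg
	}
}
```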
@@ -68,17 +122,17 @@ func RunPrompt(client *openai.Client) error {
 		case "model":
 			if len(parts) == 1 {
-				fmt.Println("model is set to", Model)
+				fmt.Println("model is set to", req.Model)
 				continue
 			}
-			Model = parts[1]
-			fmt.Println("model is now", Model)
+			req.Model = parts[1]
+			fmt.Println("model is now", req.Model)
 			continue
 		case "tokens":
 			if len(parts) == 1 {
-				fmt.Println("tokens is set to", MaxTokens)
+				fmt.Println("tokens is set to", req.MaxTokens)
 				continue
 			}
 			c, err := strconv.Atoi(parts[1])
@@ -87,13 +141,13 @@ func RunPrompt(client *openai.Client) error {
 				continue
 			}
-			MaxTokens = c
-			fmt.Println("tokens is now", MaxTokens)
+			req.MaxTokens = c
+			fmt.Println("tokens is now", req.MaxTokens)
 			continue
 		case "count":
 			if len(parts) == 1 {
-				fmt.Println("count is set to", Count)
+				fmt.Println("count is set to", req.N)
 				continue
 			}
 			c, err := strconv.Atoi(parts[1])
@@ -102,87 +156,85 @@ func RunPrompt(client *openai.Client) error {
 				continue
 			}
-			Count = c
-			fmt.Println("count is now", Count)
+			req.N = c
+			fmt.Println("count is now", req.N)
 			continue
 		case "temp":
 			if len(parts) == 1 {
-				fmt.Println("temp is set to", Temp)
+				fmt.Println("temp is set to", req.Temperature)
 				continue
 			}
-			f, err := strconv.ParseFloat(parts[1], 64)
+			f, err := strconv.ParseFloat(parts[1], 32)
 			if err != nil {
 				fmt.Println(err)
 				continue
 			}
-			Temp = f
-			fmt.Println("temp is now", Temp)
+			req.Temperature = float32(f)
+			fmt.Println("temp is now", req.Temperature)
 		case "topp":
 			if len(parts) == 1 {
-				fmt.Println("topp is set to", TopP)
+				fmt.Println("topp is set to", req.TopP)
 				continue
 			}
-			f, err := strconv.ParseFloat(parts[1], 64)
+			f, err := strconv.ParseFloat(parts[1], 32)
 			if err != nil {
 				fmt.Println(err)
 				continue
 			}
-			TopP = f
-			fmt.Println("topp is now", TopP)
+			req.TopP = float32(f)
+			fmt.Println("topp is now", req.TopP)
 		case "pres":
 			if len(parts) == 1 {
-				fmt.Println("pres is set to", PresencePenalty)
+				fmt.Println("pres is set to", req.PresencePenalty)
 				continue
 			}
-			f, err := strconv.ParseFloat(parts[1], 64)
+			f, err := strconv.ParseFloat(parts[1], 32)
 			if err != nil {
 				fmt.Println(err)
 				continue
 			}
-			PresencePenalty = f
-			fmt.Println("pres is now", PresencePenalty)
+			req.PresencePenalty = float32(f)
+			fmt.Println("pres is now", req.PresencePenalty)
 		case "freq":
 			if len(parts) == 1 {
-				fmt.Println("freq is set to", FrequencyPenalty)
+				fmt.Println("freq is set to", req.FrequencyPenalty)
 				continue
 			}
-			f, err := strconv.ParseFloat(parts[1], 64)
+			f, err := strconv.ParseFloat(parts[1], 32)
 			if err != nil {
 				fmt.Println(err)
 				continue
 			}
-			FrequencyPenalty = f
-			fmt.Println("freq is now", FrequencyPenalty)
+			req.FrequencyPenalty = float32(f)
+			fmt.Println("freq is now", req.FrequencyPenalty)
 		default:
-			// add the question to the existing prompt text, to keep context
-			PromptText += "\n> " + question
-			var R []string
 			var err error
-			// TODO, chat mode?
-			if CodeMode {
-				// R, err = GetCodeResponse(client, ctx, PromptText)
-			} else if EditMode {
-				R, err = GetEditsResponse(client, ctx, PromptText, Question)
-			} else {
-				R, err = GetCompletionResponse(client, ctx, PromptText)
-			}
+			// add the question to the existing messages
+			msg := openai.ChatCompletionMessage{
+				Role:    "user",
+				Content: question,
+			}
+			req.Messages = append(req.Messages, msg)
+			resp, err := client.CreateChatCompletion(ctx, req)
 			if err != nil {
 				return err
 			}
+			R := resp.Choices
 			final := ""
 			if len(R) == 1 {
-				final = R[0]
+				final = R[0].Message.Content
 			} else {
 				for i, r := range R {
-					final += fmt.Sprintf("[%d]: %s\n\n", i, r)
+					final += fmt.Sprintf("[%d]: %s\n\n", i, r.Message.Content)
 				}
 				fmt.Println(final)
 				ok := false
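Each float flag switches from ParseFloat(parts[1], 64) to ParseFloat(parts[1], 32) because the values now land in float32 fields on the request. ParseFloat returns float64 regardless of bitSize; passing 32 only guarantees the result converts to float32 without changing its value, so the explicit float32(f) conversion stays lossless. A small sketch:

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	// bitSize 32 does not change the return type (still float64);
	// it guarantees the parsed value round-trips through float32 exactly
	f, err := strconv.ParseFloat("0.7", 32)
	if err != nil {
		fmt.Println(err)
		return
	}
	temp := float32(f) // matches the float32 fields on ChatCompletionRequest
	fmt.Println(temp)
}
```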
@ -208,11 +260,15 @@ func RunPrompt(client *openai.Client) error {
ok = true ok = true
} }
final = R[pos] final = R[pos].Message.Content
} }
// we add response to the prompt, this is how ChatGPT sessions keep context // we add response to the prompt, this is how ChatGPT sessions keep context
PromptText += "\n" + strings.TrimSpace(final) msg = openai.ChatCompletionMessage{
Role: "assistant",
Content: final,
}
req.Messages = append(req.Messages, msg)
// print the latest portion of the conversation // print the latest portion of the conversation
fmt.Println(final + "\n") fmt.Println(final + "\n")
} }
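After the default case runs a few times, req.Messages carries the whole conversation: user and assistant turns interleaved after the optional system prompt. That accumulated slice is what keeps context across turns, since every CreateChatCompletion call re-sends the full history. An illustration with made-up values:

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// illustrative state after two exchanges; 'clear' resets it by
	// emptying the slice, which is why the model then forgets everything
	var req openai.ChatCompletionRequest
	req.Messages = []openai.ChatCompletionMessage{
		{Role: "system", Content: "you are a terse assistant"}, // set via 'prompt'
		{Role: "user", Content: "what is a goroutine?"},
		{Role: "assistant", Content: "a lightweight thread managed by the Go runtime"},
		{Role: "user", Content: "and a channel?"},
	}
	fmt.Println(len(req.Messages), "messages would be sent on the next request")
}
```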

@@ -7,41 +7,6 @@ import (
 	"github.com/sashabaranov/go-openai"
 )
 
-/*
-func GetChatCompletionResponse(client *openai.Client, ctx context.Context, question string) ([]string, error) {
-	if CleanPrompt {
-		question = strings.ReplaceAll(question, "\n", " ")
-		question = strings.ReplaceAll(question, "  ", " ")
-	}
-	// insert newline at end to prevent completion of question
-	if !strings.HasSuffix(question, "\n") {
-		question += "\n"
-	}
-	req := openai.ChatCompletionRequest{
-		Model:            Model,
-		MaxTokens:        MaxTokens,
-		Prompt:           question,
-		Echo:             Echo,
-		N:                Count,
-		Temperature:      float32(Temp),
-		TopP:             float32(TopP),
-		PresencePenalty:  float32(PresencePenalty),
-		FrequencyPenalty: float32(FrequencyPenalty),
-	}
-	resp, err := client.CreateChatCompletion(ctx, req)
-	if err != nil {
-		return nil, err
-	}
-	var r []string
-	for _, c := range resp.Choices {
-		r = append(r, c.Text)
-	}
-	return r, nil
-}
-*/
-
 func GetCompletionResponse(client *openai.Client, ctx context.Context, question string) ([]string, error) {
 	if CleanPrompt {
 		question = strings.ReplaceAll(question, "\n", " ")
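The commented-out draft deleted here mixed the two request types: ChatCompletionRequest has no Prompt or Echo fields (those belong to the legacy CompletionRequest), and chat choices expose Message.Content rather than Text, which is presumably why it never left its comment block. For contrast, the two request shapes in go-openai (construction only, using the library's model constants):

```go
package main

import openai "github.com/sashabaranov/go-openai"

func main() {
	// legacy completion endpoint: a prompt string in, Choices[i].Text out
	creq := openai.CompletionRequest{
		Model:     openai.GPT3TextDavinci003,
		Prompt:    "Hello",
		MaxTokens: 16,
		Echo:      false,
	}

	// chat endpoint: a message list in, Choices[i].Message.Content out
	chreq := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: "user", Content: "Hello"},
		},
	}

	_, _ = creq, chreq // no API call made in this sketch
}
```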

@@ -66,21 +66,6 @@ Examples:
 `
 
-var interactiveHelp = `starting interactive session...
-'quit' to exit
-'save <filename>' to preserve
-'clear' to erase the context
-'context' to see the current context
-'prompt' to set the context to a prompt
-'tokens' to change the MaxToken param
-'count' to change number of responses
-'temp' set the temperature param [0.0,2.0]
-'topp' set the TopP param [0.0,1.0]
-'pres' set the Presence Penalty [-2.0,2.0]
-'freq' set the Frequency Penalty [-2.0,2.0]
-'model' to change the selected model
-`
-
 var Version bool
 
 // prompt vars

@@ -212,9 +197,7 @@ func main() {
 	// interactive or file mode
 	if PromptMode {
-		fmt.Println(interactiveHelp)
-		fmt.Println(PromptText)
-		err = RunPrompt(client)
+		err = RunInteractive(client)
 	} else {
 		// empty filename (no args) prints to stdout
 		err = RunOnce(client, filename)
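With the help text and prompt priming moved inside RunInteractive, main only dispatches. The loop it hands control to is, stripped of the API call and the parameter commands, essentially this runnable skeleton (a sketch, not the commit's code):

```go
package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

func main() {
	scanner := bufio.NewScanner(os.Stdin)
	for {
		fmt.Print("> ")
		if !scanner.Scan() {
			break // EOF or read error ends the session
		}
		question := scanner.Text()
		parts := strings.Fields(question)
		if len(parts) == 0 {
			continue
		}
		switch parts[0] {
		case "quit":
			return
		case "clear":
			fmt.Println("(context cleared)")
		default:
			fmt.Println("(this is where the chat request would be sent)")
		}
	}
}
```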
