diff --git a/README.md b/README.md index 54aa2ea..bbb80a4 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,7 @@ Application Options: -T, --topp= Set top P (default: 0.9) -s, --stream Stream -P, --presencepenalty= Set presence penalty (default: 0.0) + -u, --user-instead-of-system Use the user role instead of the system role for the pattern -F, --frequencypenalty= Set frequency penalty (default: 0.0) -l, --listpatterns List all patterns -L, --listmodels List all available models diff --git a/cli/flags.go b/cli/flags.go index 7de2584..ce49289 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -23,6 +23,7 @@ type Flags struct { TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"` Stream bool `short:"s" long:"stream" description:"Stream"` PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"` + UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"` FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"` ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"` ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"` @@ -89,10 +90,11 @@ func readStdin() (string, error) { func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { ret = &common.ChatOptions{ - Temperature: o.Temperature, - TopP: o.TopP, - PresencePenalty: o.PresencePenalty, - FrequencyPenalty: o.FrequencyPenalty, + Temperature: o.Temperature, + TopP: o.TopP, + PresencePenalty: o.PresencePenalty, + FrequencyPenalty: o.FrequencyPenalty, + UserInsteadOfSystemRole: o.UserInsteadOfSystemRole, } return } diff --git a/cli/flags_test.go b/cli/flags_test.go index 865a262..894cbf0 100644 --- a/cli/flags_test.go +++ b/cli/flags_test.go @@ -56,10 +56,11 @@ func TestBuildChatOptions(t *testing.T) { } expectedOptions := &common.ChatOptions{ - Temperature: 0.8, - TopP: 0.9, - PresencePenalty: 0.1, - FrequencyPenalty: 0.2, + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + UserInsteadOfSystemRole: false, } options := flags.BuildChatOptions() assert.Equal(t, expectedOptions, options) diff --git a/common/domain.go b/common/domain.go index f546830..4019486 100644 --- a/common/domain.go +++ b/common/domain.go @@ -1,5 +1,7 @@ package common +import goopenai "github.com/sashabaranov/go-openai" + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -14,11 +16,12 @@ type ChatRequest struct { } type ChatOptions struct { - Model string - Temperature float64 - TopP float64 - PresencePenalty float64 - FrequencyPenalty float64 + Model string + Temperature float64 + TopP float64 + PresencePenalty float64 + FrequencyPenalty float64 + UserInsteadOfSystemRole bool } // NormalizeMessages remove empty messages and ensure messages order user-assist-user @@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa } // Ensure, that each odd position shall be a user message - if fullMessageIndex%2 == 0 && message.Role != "user" { - ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) + if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser { + ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage}) fullMessageIndex++ } ret = append(ret, message) diff --git a/common/domain_test.go b/common/domain_test.go index a4b5ffe..49dcdf3 100644 --- a/common/domain_test.go +++ b/common/domain_test.go @@ -1,23 +1,24 @@ package common import ( + goopenai "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/assert" "testing" ) func TestNormalizeMessages(t *testing.T) { msgs := []*Message{ - {Role: "user", Content: "Hello"}, - {Role: "bot", Content: "Hi there!"}, - {Role: "bot", Content: ""}, - {Role: "user", Content: ""}, - {Role: "user", Content: "How are you?"}, + {Role: goopenai.ChatMessageRoleUser, Content: "Hello"}, + {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"}, + {Role: goopenai.ChatMessageRoleUser, Content: ""}, + {Role: goopenai.ChatMessageRoleUser, Content: ""}, + {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"}, } expected := []*Message{ - {Role: "user", Content: "Hello"}, - {Role: "bot", Content: "Hi there!"}, - {Role: "user", Content: "How are you?"}, + {Role: goopenai.ChatMessageRoleUser, Content: "Hello"}, + {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"}, + {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"}, } actual := NormalizeMessages(msgs, "default") diff --git a/core/chatter.go b/core/chatter.go index f90a0c0..2215a6e 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -3,10 +3,10 @@ package core import ( "context" "fmt" - "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/vendors" + goopenai "github.com/sashabaranov/go-openai" ) type Chatter struct { @@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m } var session *db.Session - if session, err = chatRequest.BuildChatSession(); err != nil { + if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil { return } @@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m } if chatRequest.Session != nil && message != "" { - chatRequest.Session.Append(&common.Message{Role: "system", Content: message}) + chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message}) err = o.db.Sessions.SaveSession(chatRequest.Session) } return diff --git a/core/chatter_test.go b/core/chatter_test.go index 70966e7..3336e0d 100644 --- a/core/chatter_test.go +++ b/core/chatter_test.go @@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) { Pattern: "test pattern", Message: "test message", } - session, err := chat.BuildChatSession() + session, err := chat.BuildChatSession(false) if err != nil { t.Fatalf("BuildChatSession() error = %v", err) } diff --git a/core/fabric.go b/core/fabric.go index eec56fd..ea47298 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "github.com/danielmiessler/fabric/vendors/groq" + goopenai "github.com/sashabaranov/go-openai" "os" "strconv" "strings" @@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) { return } -func (o *Chat) BuildChatSession() (ret *db.Session, err error) { +func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) { // new messages will be appended to the session and used to send the message if o.Session != nil { ret = o.Session @@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) { } systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern) - - if systemMessage != "" { - ret.Append(&common.Message{Role: "system", Content: systemMessage}) - } - userMessage := strings.TrimSpace(o.Message) - if userMessage != "" { - ret.Append(&common.Message{Role: "user", Content: userMessage}) + + if userInsteadOfSystemRole { + message := systemMessage + userMessage + if message != "" { + ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message}) + } + } else { + if systemMessage != "" { + ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage}) + } + if userMessage != "" { + ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage}) + } } if ret.IsEmpty() { diff --git a/vendors/anthropic/anthropic.go b/vendors/anthropic/anthropic.go index 5f62ac0..0da9c0f 100644 --- a/vendors/anthropic/anthropic.go +++ b/vendors/anthropic/anthropic.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + goopenai "github.com/sashabaranov/go-openai" "github.com/danielmiessler/fabric/common" "github.com/liushuangls/go-anthropic/v2" @@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) { for _, msg := range normalizedMessages { var message anthropic.Message switch msg.Role { - case "user": + case goopenai.ChatMessageRoleUser: message = anthropic.NewUserTextMessage(msg.Content) - case "system": - message = anthropic.NewAssistantTextMessage(msg.Content) default: message = anthropic.NewAssistantTextMessage(msg.Content) } diff --git a/vendors/dryrun/dryrun.go b/vendors/dryrun/dryrun.go index c13350c..5d3d077 100644 --- a/vendors/dryrun/dryrun.go +++ b/vendors/dryrun/dryrun.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + goopenai "github.com/sashabaranov/go-openai" "github.com/danielmiessler/fabric/common" ) @@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch for _, msg := range msgs { switch msg.Role { - case "system": + case goopenai.ChatMessageRoleSystem: output += fmt.Sprintf("System:\n%s\n\n", msg.Content) - case "user": + case goopenai.ChatMessageRoleAssistant: + output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content) + case goopenai.ChatMessageRoleUser: output += fmt.Sprintf("User:\n%s\n\n", msg.Content) default: output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content) @@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch return nil } -func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) { +func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) { fmt.Println("Dry run: Would send the following request:") for _, msg := range msgs { switch msg.Role { - case "system": + case goopenai.ChatMessageRoleSystem: fmt.Printf("System:\n%s\n\n", msg.Content) - case "user": + case goopenai.ChatMessageRoleAssistant: + fmt.Printf("Assistant:\n%s\n\n", msg.Content) + case goopenai.ChatMessageRoleUser: fmt.Printf("User:\n%s\n\n", msg.Content) default: fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content) @@ -84,6 +89,6 @@ func (c *Client) Setup() error { return nil } -func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) { +func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) { // No environment variables needed for dry run } diff --git a/vendors/openai/openai.go b/vendors/openai/openai.go index e9c9755..b382a97 100644 --- a/vendors/openai/openai.go +++ b/vendors/openai/openai.go @@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest( msgs []*common.Message, opts *common.ChatOptions, ) (ret goopenai.ChatCompletionRequest) { messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage { - var role string - - switch message.Role { - case "user": - role = goopenai.ChatMessageRoleUser - case "system": - role = goopenai.ChatMessageRoleSystem - default: - role = goopenai.ChatMessageRoleSystem - } - return goopenai.ChatCompletionMessage{Role: role, Content: message.Content} + return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content} }) ret = goopenai.ChatCompletionRequest{