feat: implement -u, --user-instead-of-system: Use the user role instead of the system role for the pattern. It is needed for Open AI o1 models for now.

This commit is contained in:
Eugen Eisler 2024-09-15 15:09:45 +02:00
parent 19a0b8a1d6
commit 329c843567
11 changed files with 64 additions and 55 deletions

View File

@ -195,6 +195,7 @@ Application Options:
-T, --topp= Set top P (default: 0.9) -T, --topp= Set top P (default: 0.9)
-s, --stream Stream -s, --stream Stream
-P, --presencepenalty= Set presence penalty (default: 0.0) -P, --presencepenalty= Set presence penalty (default: 0.0)
-u, --user-instead-of-system Use the user role instead of the system role for the pattern
-F, --frequencypenalty= Set frequency penalty (default: 0.0) -F, --frequencypenalty= Set frequency penalty (default: 0.0)
-l, --listpatterns List all patterns -l, --listpatterns List all patterns
-L, --listmodels List all available models -L, --listmodels List all available models

View File

@ -23,6 +23,7 @@ type Flags struct {
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"` TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" description:"Stream"` Stream bool `short:"s" long:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"` PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"` FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"` ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"` ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
@ -89,10 +90,11 @@ func readStdin() (string, error) {
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
ret = &common.ChatOptions{ ret = &common.ChatOptions{
Temperature: o.Temperature, Temperature: o.Temperature,
TopP: o.TopP, TopP: o.TopP,
PresencePenalty: o.PresencePenalty, PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty, FrequencyPenalty: o.FrequencyPenalty,
UserInsteadOfSystemRole: o.UserInsteadOfSystemRole,
} }
return return
} }

View File

@ -56,10 +56,11 @@ func TestBuildChatOptions(t *testing.T) {
} }
expectedOptions := &common.ChatOptions{ expectedOptions := &common.ChatOptions{
Temperature: 0.8, Temperature: 0.8,
TopP: 0.9, TopP: 0.9,
PresencePenalty: 0.1, PresencePenalty: 0.1,
FrequencyPenalty: 0.2, FrequencyPenalty: 0.2,
UserInsteadOfSystemRole: false,
} }
options := flags.BuildChatOptions() options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options) assert.Equal(t, expectedOptions, options)

View File

@ -1,5 +1,7 @@
package common package common
import goopenai "github.com/sashabaranov/go-openai"
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
@ -14,11 +16,12 @@ type ChatRequest struct {
} }
type ChatOptions struct { type ChatOptions struct {
Model string Model string
Temperature float64 Temperature float64
TopP float64 TopP float64
PresencePenalty float64 PresencePenalty float64
FrequencyPenalty float64 FrequencyPenalty float64
UserInsteadOfSystemRole bool
} }
// NormalizeMessages remove empty messages and ensure messages order user-assist-user // NormalizeMessages remove empty messages and ensure messages order user-assist-user
@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa
} }
// Ensure, that each odd position shall be a user message // Ensure, that each odd position shall be a user message
if fullMessageIndex%2 == 0 && message.Role != "user" { if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
fullMessageIndex++ fullMessageIndex++
} }
ret = append(ret, message) ret = append(ret, message)

View File

@ -1,23 +1,24 @@
package common package common
import ( import (
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
func TestNormalizeMessages(t *testing.T) { func TestNormalizeMessages(t *testing.T) {
msgs := []*Message{ msgs := []*Message{
{Role: "user", Content: "Hello"}, {Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: "bot", Content: "Hi there!"}, {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: "bot", Content: ""}, {Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: "user", Content: ""}, {Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: "user", Content: "How are you?"}, {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
} }
expected := []*Message{ expected := []*Message{
{Role: "user", Content: "Hello"}, {Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: "bot", Content: "Hi there!"}, {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: "user", Content: "How are you?"}, {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
} }
actual := NormalizeMessages(msgs, "default") actual := NormalizeMessages(msgs, "default")

View File

@ -3,10 +3,10 @@ package core
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors" "github.com/danielmiessler/fabric/vendors"
goopenai "github.com/sashabaranov/go-openai"
) )
type Chatter struct { type Chatter struct {
@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
} }
var session *db.Session var session *db.Session
if session, err = chatRequest.BuildChatSession(); err != nil { if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil {
return return
} }
@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
} }
if chatRequest.Session != nil && message != "" { if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append(&common.Message{Role: "system", Content: message}) chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message})
err = o.db.Sessions.SaveSession(chatRequest.Session) err = o.db.Sessions.SaveSession(chatRequest.Session)
} }
return return

View File

@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) {
Pattern: "test pattern", Pattern: "test pattern",
Message: "test message", Message: "test message",
} }
session, err := chat.BuildChatSession() session, err := chat.BuildChatSession(false)
if err != nil { if err != nil {
t.Fatalf("BuildChatSession() error = %v", err) t.Fatalf("BuildChatSession() error = %v", err)
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/danielmiessler/fabric/vendors/groq" "github.com/danielmiessler/fabric/vendors/groq"
goopenai "github.com/sashabaranov/go-openai"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
return return
} }
func (o *Chat) BuildChatSession() (ret *db.Session, err error) { func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) {
// new messages will be appended to the session and used to send the message // new messages will be appended to the session and used to send the message
if o.Session != nil { if o.Session != nil {
ret = o.Session ret = o.Session
@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
} }
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern) systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
if systemMessage != "" {
ret.Append(&common.Message{Role: "system", Content: systemMessage})
}
userMessage := strings.TrimSpace(o.Message) userMessage := strings.TrimSpace(o.Message)
if userMessage != "" {
ret.Append(&common.Message{Role: "user", Content: userMessage}) if userInsteadOfSystemRole {
message := systemMessage + userMessage
if message != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message})
}
} else {
if systemMessage != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
}
if userMessage != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage})
}
} }
if ret.IsEmpty() { if ret.IsEmpty() {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/liushuangls/go-anthropic/v2" "github.com/liushuangls/go-anthropic/v2"
@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) {
for _, msg := range normalizedMessages { for _, msg := range normalizedMessages {
var message anthropic.Message var message anthropic.Message
switch msg.Role { switch msg.Role {
case "user": case goopenai.ChatMessageRoleUser:
message = anthropic.NewUserTextMessage(msg.Content) message = anthropic.NewUserTextMessage(msg.Content)
case "system":
message = anthropic.NewAssistantTextMessage(msg.Content)
default: default:
message = anthropic.NewAssistantTextMessage(msg.Content) message = anthropic.NewAssistantTextMessage(msg.Content)
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
) )
@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
for _, msg := range msgs { for _, msg := range msgs {
switch msg.Role { switch msg.Role {
case "system": case goopenai.ChatMessageRoleSystem:
output += fmt.Sprintf("System:\n%s\n\n", msg.Content) output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
case "user": case goopenai.ChatMessageRoleAssistant:
output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
output += fmt.Sprintf("User:\n%s\n\n", msg.Content) output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
default: default:
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content) output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return nil return nil
} }
func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) { func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
fmt.Println("Dry run: Would send the following request:") fmt.Println("Dry run: Would send the following request:")
for _, msg := range msgs { for _, msg := range msgs {
switch msg.Role { switch msg.Role {
case "system": case goopenai.ChatMessageRoleSystem:
fmt.Printf("System:\n%s\n\n", msg.Content) fmt.Printf("System:\n%s\n\n", msg.Content)
case "user": case goopenai.ChatMessageRoleAssistant:
fmt.Printf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
fmt.Printf("User:\n%s\n\n", msg.Content) fmt.Printf("User:\n%s\n\n", msg.Content)
default: default:
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content) fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
@ -84,6 +89,6 @@ func (c *Client) Setup() error {
return nil return nil
} }
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) { func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) {
// No environment variables needed for dry run // No environment variables needed for dry run
} }

View File

@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest(
msgs []*common.Message, opts *common.ChatOptions, msgs []*common.Message, opts *common.ChatOptions,
) (ret goopenai.ChatCompletionRequest) { ) (ret goopenai.ChatCompletionRequest) {
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage { messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
var role string return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
switch message.Role {
case "user":
role = goopenai.ChatMessageRoleUser
case "system":
role = goopenai.ChatMessageRoleSystem
default:
role = goopenai.ChatMessageRoleSystem
}
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
}) })
ret = goopenai.ChatCompletionRequest{ ret = goopenai.ChatCompletionRequest{