feat: implement -u, --user-instead-of-system: Use the user role instead of the system role for the pattern. It is needed for Open AI o1 models for now.

This commit is contained in:
Eugen Eisler 2024-09-15 15:09:45 +02:00
parent 19a0b8a1d6
commit 329c843567
11 changed files with 64 additions and 55 deletions

View File

@ -195,6 +195,7 @@ Application Options:
-T, --topp= Set top P (default: 0.9)
-s, --stream Stream
-P, --presencepenalty= Set presence penalty (default: 0.0)
-u, --user-instead-of-system Use the user role instead of the system role for the pattern
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
-l, --listpatterns List all patterns
-L, --listmodels List all available models

View File

@ -23,6 +23,7 @@ type Flags struct {
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
@ -93,6 +94,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
TopP: o.TopP,
PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty,
UserInsteadOfSystemRole: o.UserInsteadOfSystemRole,
}
return
}

View File

@ -60,6 +60,7 @@ func TestBuildChatOptions(t *testing.T) {
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
UserInsteadOfSystemRole: false,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)

View File

@ -1,5 +1,7 @@
package common
import goopenai "github.com/sashabaranov/go-openai"
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
@ -19,6 +21,7 @@ type ChatOptions struct {
TopP float64
PresencePenalty float64
FrequencyPenalty float64
UserInsteadOfSystemRole bool
}
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa
}
// Ensure, that each odd position shall be a user message
if fullMessageIndex%2 == 0 && message.Role != "user" {
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
fullMessageIndex++
}
ret = append(ret, message)

View File

@ -1,23 +1,24 @@
package common
import (
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNormalizeMessages(t *testing.T) {
msgs := []*Message{
{Role: "user", Content: "Hello"},
{Role: "bot", Content: "Hi there!"},
{Role: "bot", Content: ""},
{Role: "user", Content: ""},
{Role: "user", Content: "How are you?"},
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
}
expected := []*Message{
{Role: "user", Content: "Hello"},
{Role: "bot", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
}
actual := NormalizeMessages(msgs, "default")

View File

@ -3,10 +3,10 @@ package core
import (
"context"
"fmt"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors"
goopenai "github.com/sashabaranov/go-openai"
)
type Chatter struct {
@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
}
var session *db.Session
if session, err = chatRequest.BuildChatSession(); err != nil {
if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil {
return
}
@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
}
if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message})
err = o.db.Sessions.SaveSession(chatRequest.Session)
}
return

View File

@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) {
Pattern: "test pattern",
Message: "test message",
}
session, err := chat.BuildChatSession()
session, err := chat.BuildChatSession(false)
if err != nil {
t.Fatalf("BuildChatSession() error = %v", err)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"github.com/danielmiessler/fabric/vendors/groq"
goopenai "github.com/sashabaranov/go-openai"
"os"
"strconv"
"strings"
@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
return
}
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) {
// new messages will be appended to the session and used to send the message
if o.Session != nil {
ret = o.Session
@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
}
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
if systemMessage != "" {
ret.Append(&common.Message{Role: "system", Content: systemMessage})
}
userMessage := strings.TrimSpace(o.Message)
if userInsteadOfSystemRole {
message := systemMessage + userMessage
if message != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message})
}
} else {
if systemMessage != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
}
if userMessage != "" {
ret.Append(&common.Message{Role: "user", Content: userMessage})
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage})
}
}
if ret.IsEmpty() {

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
"github.com/liushuangls/go-anthropic/v2"
@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) {
for _, msg := range normalizedMessages {
var message anthropic.Message
switch msg.Role {
case "user":
case goopenai.ChatMessageRoleUser:
message = anthropic.NewUserTextMessage(msg.Content)
case "system":
message = anthropic.NewAssistantTextMessage(msg.Content)
default:
message = anthropic.NewAssistantTextMessage(msg.Content)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
)
@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
for _, msg := range msgs {
switch msg.Role {
case "system":
case goopenai.ChatMessageRoleSystem:
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
case "user":
case goopenai.ChatMessageRoleAssistant:
output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
default:
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return nil
}
func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
fmt.Println("Dry run: Would send the following request:")
for _, msg := range msgs {
switch msg.Role {
case "system":
case goopenai.ChatMessageRoleSystem:
fmt.Printf("System:\n%s\n\n", msg.Content)
case "user":
case goopenai.ChatMessageRoleAssistant:
fmt.Printf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
fmt.Printf("User:\n%s\n\n", msg.Content)
default:
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
@ -84,6 +89,6 @@ func (c *Client) Setup() error {
return nil
}
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) {
// No environment variables needed for dry run
}

View File

@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest(
msgs []*common.Message, opts *common.ChatOptions,
) (ret goopenai.ChatCompletionRequest) {
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
var role string
switch message.Role {
case "user":
role = goopenai.ChatMessageRoleUser
case "system":
role = goopenai.ChatMessageRoleSystem
default:
role = goopenai.ChatMessageRoleSystem
}
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
})
ret = goopenai.ChatCompletionRequest{