mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
feat: implement -u, --user-instead-of-system: Use the user role instead of the system role for the pattern. It is needed for Open AI o1 models for now.
This commit is contained in:
parent
19a0b8a1d6
commit
329c843567
@ -195,6 +195,7 @@ Application Options:
|
||||
-T, --topp= Set top P (default: 0.9)
|
||||
-s, --stream Stream
|
||||
-P, --presencepenalty= Set presence penalty (default: 0.0)
|
||||
-u, --user-instead-of-system Use the user role instead of the system role for the pattern
|
||||
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
|
||||
-l, --listpatterns List all patterns
|
||||
-L, --listmodels List all available models
|
||||
|
@ -23,6 +23,7 @@ type Flags struct {
|
||||
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
|
||||
Stream bool `short:"s" long:"stream" description:"Stream"`
|
||||
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
|
||||
UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"`
|
||||
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
|
||||
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
|
||||
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
|
||||
@ -93,6 +94,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
TopP: o.TopP,
|
||||
PresencePenalty: o.PresencePenalty,
|
||||
FrequencyPenalty: o.FrequencyPenalty,
|
||||
UserInsteadOfSystemRole: o.UserInsteadOfSystemRole,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -60,6 +60,7 @@ func TestBuildChatOptions(t *testing.T) {
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
UserInsteadOfSystemRole: false,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
|
@ -1,5 +1,7 @@
|
||||
package common
|
||||
|
||||
import goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
@ -19,6 +21,7 @@ type ChatOptions struct {
|
||||
TopP float64
|
||||
PresencePenalty float64
|
||||
FrequencyPenalty float64
|
||||
UserInsteadOfSystemRole bool
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa
|
||||
}
|
||||
|
||||
// Ensure, that each odd position shall be a user message
|
||||
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
||||
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
||||
if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
|
||||
ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
|
||||
fullMessageIndex++
|
||||
}
|
||||
ret = append(ret, message)
|
||||
|
@ -1,23 +1,24 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeMessages(t *testing.T) {
|
||||
msgs := []*Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "bot", Content: "Hi there!"},
|
||||
{Role: "bot", Content: ""},
|
||||
{Role: "user", Content: ""},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||
}
|
||||
|
||||
expected := []*Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "bot", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||
}
|
||||
|
||||
actual := NormalizeMessages(msgs, "default")
|
||||
|
@ -3,10 +3,10 @@ package core
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
"github.com/danielmiessler/fabric/vendors"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Chatter struct {
|
||||
@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
}
|
||||
|
||||
var session *db.Session
|
||||
if session, err = chatRequest.BuildChatSession(); err != nil {
|
||||
if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
}
|
||||
|
||||
if chatRequest.Session != nil && message != "" {
|
||||
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
|
||||
chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message})
|
||||
err = o.db.Sessions.SaveSession(chatRequest.Session)
|
||||
}
|
||||
return
|
||||
|
@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) {
|
||||
Pattern: "test pattern",
|
||||
Message: "test message",
|
||||
}
|
||||
session, err := chat.BuildChatSession()
|
||||
session, err := chat.BuildChatSession(false)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildChatSession() error = %v", err)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/danielmiessler/fabric/vendors/groq"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
||||
func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) {
|
||||
// new messages will be appended to the session and used to send the message
|
||||
if o.Session != nil {
|
||||
ret = o.Session
|
||||
@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
||||
}
|
||||
|
||||
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
||||
|
||||
if systemMessage != "" {
|
||||
ret.Append(&common.Message{Role: "system", Content: systemMessage})
|
||||
}
|
||||
|
||||
userMessage := strings.TrimSpace(o.Message)
|
||||
|
||||
if userInsteadOfSystemRole {
|
||||
message := systemMessage + userMessage
|
||||
if message != "" {
|
||||
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message})
|
||||
}
|
||||
} else {
|
||||
if systemMessage != "" {
|
||||
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
|
||||
}
|
||||
if userMessage != "" {
|
||||
ret.Append(&common.Message{Role: "user", Content: userMessage})
|
||||
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage})
|
||||
}
|
||||
}
|
||||
|
||||
if ret.IsEmpty() {
|
||||
|
5
vendors/anthropic/anthropic.go
vendored
5
vendors/anthropic/anthropic.go
vendored
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/liushuangls/go-anthropic/v2"
|
||||
@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) {
|
||||
for _, msg := range normalizedMessages {
|
||||
var message anthropic.Message
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
message = anthropic.NewUserTextMessage(msg.Content)
|
||||
case "system":
|
||||
message = anthropic.NewAssistantTextMessage(msg.Content)
|
||||
default:
|
||||
message = anthropic.NewAssistantTextMessage(msg.Content)
|
||||
}
|
||||
|
17
vendors/dryrun/dryrun.go
vendored
17
vendors/dryrun/dryrun.go
vendored
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
|
||||
case "user":
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content)
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
|
||||
default:
|
||||
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
|
||||
@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
|
||||
func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
|
||||
fmt.Println("Dry run: Would send the following request:")
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
fmt.Printf("System:\n%s\n\n", msg.Content)
|
||||
case "user":
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
fmt.Printf("Assistant:\n%s\n\n", msg.Content)
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
fmt.Printf("User:\n%s\n\n", msg.Content)
|
||||
default:
|
||||
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
|
||||
@ -84,6 +89,6 @@ func (c *Client) Setup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
||||
func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) {
|
||||
// No environment variables needed for dry run
|
||||
}
|
||||
|
12
vendors/openai/openai.go
vendored
12
vendors/openai/openai.go
vendored
@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest(
|
||||
msgs []*common.Message, opts *common.ChatOptions,
|
||||
) (ret goopenai.ChatCompletionRequest) {
|
||||
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
|
||||
var role string
|
||||
|
||||
switch message.Role {
|
||||
case "user":
|
||||
role = goopenai.ChatMessageRoleUser
|
||||
case "system":
|
||||
role = goopenai.ChatMessageRoleSystem
|
||||
default:
|
||||
role = goopenai.ChatMessageRoleSystem
|
||||
}
|
||||
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
|
||||
return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
|
||||
})
|
||||
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
|
Loading…
Reference in New Issue
Block a user