mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
feat: implement -u, --user-instead-of-system: Use the user role instead of the system role for the pattern. It is needed for Open AI o1 models for now.
This commit is contained in:
parent
19a0b8a1d6
commit
329c843567
@ -195,6 +195,7 @@ Application Options:
|
|||||||
-T, --topp= Set top P (default: 0.9)
|
-T, --topp= Set top P (default: 0.9)
|
||||||
-s, --stream Stream
|
-s, --stream Stream
|
||||||
-P, --presencepenalty= Set presence penalty (default: 0.0)
|
-P, --presencepenalty= Set presence penalty (default: 0.0)
|
||||||
|
-u, --user-instead-of-system Use the user role instead of the system role for the pattern
|
||||||
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
|
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
|
||||||
-l, --listpatterns List all patterns
|
-l, --listpatterns List all patterns
|
||||||
-L, --listmodels List all available models
|
-L, --listmodels List all available models
|
||||||
|
10
cli/flags.go
10
cli/flags.go
@ -23,6 +23,7 @@ type Flags struct {
|
|||||||
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
|
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
|
||||||
Stream bool `short:"s" long:"stream" description:"Stream"`
|
Stream bool `short:"s" long:"stream" description:"Stream"`
|
||||||
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
|
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
|
||||||
|
UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"`
|
||||||
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
|
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
|
||||||
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
|
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
|
||||||
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
|
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
|
||||||
@ -89,10 +90,11 @@ func readStdin() (string, error) {
|
|||||||
|
|
||||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||||
ret = &common.ChatOptions{
|
ret = &common.ChatOptions{
|
||||||
Temperature: o.Temperature,
|
Temperature: o.Temperature,
|
||||||
TopP: o.TopP,
|
TopP: o.TopP,
|
||||||
PresencePenalty: o.PresencePenalty,
|
PresencePenalty: o.PresencePenalty,
|
||||||
FrequencyPenalty: o.FrequencyPenalty,
|
FrequencyPenalty: o.FrequencyPenalty,
|
||||||
|
UserInsteadOfSystemRole: o.UserInsteadOfSystemRole,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -56,10 +56,11 @@ func TestBuildChatOptions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expectedOptions := &common.ChatOptions{
|
expectedOptions := &common.ChatOptions{
|
||||||
Temperature: 0.8,
|
Temperature: 0.8,
|
||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
PresencePenalty: 0.1,
|
PresencePenalty: 0.1,
|
||||||
FrequencyPenalty: 0.2,
|
FrequencyPenalty: 0.2,
|
||||||
|
UserInsteadOfSystemRole: false,
|
||||||
}
|
}
|
||||||
options := flags.BuildChatOptions()
|
options := flags.BuildChatOptions()
|
||||||
assert.Equal(t, expectedOptions, options)
|
assert.Equal(t, expectedOptions, options)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
import goopenai "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@ -14,11 +16,12 @@ type ChatRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatOptions struct {
|
type ChatOptions struct {
|
||||||
Model string
|
Model string
|
||||||
Temperature float64
|
Temperature float64
|
||||||
TopP float64
|
TopP float64
|
||||||
PresencePenalty float64
|
PresencePenalty float64
|
||||||
FrequencyPenalty float64
|
FrequencyPenalty float64
|
||||||
|
UserInsteadOfSystemRole bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||||
@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure, that each odd position shall be a user message
|
// Ensure, that each odd position shall be a user message
|
||||||
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
|
||||||
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
|
||||||
fullMessageIndex++
|
fullMessageIndex++
|
||||||
}
|
}
|
||||||
ret = append(ret, message)
|
ret = append(ret, message)
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNormalizeMessages(t *testing.T) {
|
func TestNormalizeMessages(t *testing.T) {
|
||||||
msgs := []*Message{
|
msgs := []*Message{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||||
{Role: "bot", Content: "Hi there!"},
|
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||||
{Role: "bot", Content: ""},
|
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||||
{Role: "user", Content: ""},
|
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||||
{Role: "user", Content: "How are you?"},
|
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []*Message{
|
expected := []*Message{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||||
{Role: "bot", Content: "Hi there!"},
|
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||||
{Role: "user", Content: "How are you?"},
|
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||||
}
|
}
|
||||||
|
|
||||||
actual := NormalizeMessages(msgs, "default")
|
actual := NormalizeMessages(msgs, "default")
|
||||||
|
@ -3,10 +3,10 @@ package core
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
"github.com/danielmiessler/fabric/vendors"
|
"github.com/danielmiessler/fabric/vendors"
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Chatter struct {
|
type Chatter struct {
|
||||||
@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
var session *db.Session
|
var session *db.Session
|
||||||
if session, err = chatRequest.BuildChatSession(); err != nil {
|
if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
if chatRequest.Session != nil && message != "" {
|
if chatRequest.Session != nil && message != "" {
|
||||||
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
|
chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message})
|
||||||
err = o.db.Sessions.SaveSession(chatRequest.Session)
|
err = o.db.Sessions.SaveSession(chatRequest.Session)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) {
|
|||||||
Pattern: "test pattern",
|
Pattern: "test pattern",
|
||||||
Message: "test message",
|
Message: "test message",
|
||||||
}
|
}
|
||||||
session, err := chat.BuildChatSession()
|
session, err := chat.BuildChatSession(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("BuildChatSession() error = %v", err)
|
t.Fatalf("BuildChatSession() error = %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/danielmiessler/fabric/vendors/groq"
|
"github.com/danielmiessler/fabric/vendors/groq"
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) {
|
||||||
// new messages will be appended to the session and used to send the message
|
// new messages will be appended to the session and used to send the message
|
||||||
if o.Session != nil {
|
if o.Session != nil {
|
||||||
ret = o.Session
|
ret = o.Session
|
||||||
@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
||||||
|
|
||||||
if systemMessage != "" {
|
|
||||||
ret.Append(&common.Message{Role: "system", Content: systemMessage})
|
|
||||||
}
|
|
||||||
|
|
||||||
userMessage := strings.TrimSpace(o.Message)
|
userMessage := strings.TrimSpace(o.Message)
|
||||||
if userMessage != "" {
|
|
||||||
ret.Append(&common.Message{Role: "user", Content: userMessage})
|
if userInsteadOfSystemRole {
|
||||||
|
message := systemMessage + userMessage
|
||||||
|
if message != "" {
|
||||||
|
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if systemMessage != "" {
|
||||||
|
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
|
||||||
|
}
|
||||||
|
if userMessage != "" {
|
||||||
|
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ret.IsEmpty() {
|
if ret.IsEmpty() {
|
||||||
|
5
vendors/anthropic/anthropic.go
vendored
5
vendors/anthropic/anthropic.go
vendored
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/liushuangls/go-anthropic/v2"
|
"github.com/liushuangls/go-anthropic/v2"
|
||||||
@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) {
|
|||||||
for _, msg := range normalizedMessages {
|
for _, msg := range normalizedMessages {
|
||||||
var message anthropic.Message
|
var message anthropic.Message
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case "user":
|
case goopenai.ChatMessageRoleUser:
|
||||||
message = anthropic.NewUserTextMessage(msg.Content)
|
message = anthropic.NewUserTextMessage(msg.Content)
|
||||||
case "system":
|
|
||||||
message = anthropic.NewAssistantTextMessage(msg.Content)
|
|
||||||
default:
|
default:
|
||||||
message = anthropic.NewAssistantTextMessage(msg.Content)
|
message = anthropic.NewAssistantTextMessage(msg.Content)
|
||||||
}
|
}
|
||||||
|
17
vendors/dryrun/dryrun.go
vendored
17
vendors/dryrun/dryrun.go
vendored
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
|||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case "system":
|
case goopenai.ChatMessageRoleSystem:
|
||||||
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
|
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
|
||||||
case "user":
|
case goopenai.ChatMessageRoleAssistant:
|
||||||
|
output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content)
|
||||||
|
case goopenai.ChatMessageRoleUser:
|
||||||
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
|
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
|
||||||
default:
|
default:
|
||||||
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
|
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
|
||||||
@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
|
func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
|
||||||
fmt.Println("Dry run: Would send the following request:")
|
fmt.Println("Dry run: Would send the following request:")
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case "system":
|
case goopenai.ChatMessageRoleSystem:
|
||||||
fmt.Printf("System:\n%s\n\n", msg.Content)
|
fmt.Printf("System:\n%s\n\n", msg.Content)
|
||||||
case "user":
|
case goopenai.ChatMessageRoleAssistant:
|
||||||
|
fmt.Printf("Assistant:\n%s\n\n", msg.Content)
|
||||||
|
case goopenai.ChatMessageRoleUser:
|
||||||
fmt.Printf("User:\n%s\n\n", msg.Content)
|
fmt.Printf("User:\n%s\n\n", msg.Content)
|
||||||
default:
|
default:
|
||||||
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
|
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
|
||||||
@ -84,6 +89,6 @@ func (c *Client) Setup() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) {
|
||||||
// No environment variables needed for dry run
|
// No environment variables needed for dry run
|
||||||
}
|
}
|
||||||
|
12
vendors/openai/openai.go
vendored
12
vendors/openai/openai.go
vendored
@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest(
|
|||||||
msgs []*common.Message, opts *common.ChatOptions,
|
msgs []*common.Message, opts *common.ChatOptions,
|
||||||
) (ret goopenai.ChatCompletionRequest) {
|
) (ret goopenai.ChatCompletionRequest) {
|
||||||
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
|
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
|
||||||
var role string
|
return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
|
||||||
|
|
||||||
switch message.Role {
|
|
||||||
case "user":
|
|
||||||
role = goopenai.ChatMessageRoleUser
|
|
||||||
case "system":
|
|
||||||
role = goopenai.ChatMessageRoleSystem
|
|
||||||
default:
|
|
||||||
role = goopenai.ChatMessageRoleSystem
|
|
||||||
}
|
|
||||||
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
ret = goopenai.ChatCompletionRequest{
|
ret = goopenai.ChatCompletionRequest{
|
||||||
|
Loading…
Reference in New Issue
Block a user