// File: fabric/core/chatter.go (Go, 104 lines, 2.2 KiB)
// Revision: 2024-08-16 19:43:27 +00:00

package core
import (
	"fmt"

	"github.com/danielmiessler/fabric/common"
	"github.com/danielmiessler/fabric/db"
	"github.com/danielmiessler/fabric/vendors"
)
type Chatter struct {
db *db.Db
Stream bool
model string
vendor vendors.Vendor
2024-08-16 19:43:27 +00:00
}
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
2024-08-16 19:43:27 +00:00
var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil {
return
}
var session *db.Session
if session, err = chatRequest.BuildChatSession(); err != nil {
2024-08-16 19:43:27 +00:00
return
}
if opts.Model == "" {
opts.Model = o.model
}
if o.Stream {
channel := make(chan string)
go func() {
if streamErr := o.vendor.SendStream(session.Messages, opts, channel); streamErr != nil {
2024-08-16 19:43:27 +00:00
channel <- streamErr.Error()
}
}()
for response := range channel {
message += response
fmt.Print(response)
}
} else {
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
2024-08-16 19:43:27 +00:00
return
}
}
if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
err = o.db.Sessions.SaveSession(chatRequest.Session)
2024-08-16 19:43:27 +00:00
}
return
}
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
2024-08-16 19:43:27 +00:00
ret = &Chat{}
if request.ContextName != "" {
var ctx *db.Context
if ctx, err = o.db.Contexts.GetContext(request.ContextName); err != nil {
2024-08-16 19:43:27 +00:00
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
return
}
ret.Context = ctx.Content
}
if request.SessionName != "" {
var sess *db.Session
if sess, err = o.db.Sessions.GetOrCreateSession(request.SessionName); err != nil {
2024-08-16 19:43:27 +00:00
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
return
}
ret.Session = sess
}
if request.PatternName != "" {
var pattern *db.Pattern
if pattern, err = o.db.Patterns.GetPattern(request.PatternName); err != nil {
2024-08-16 19:43:27 +00:00
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
return
}
if pattern.Pattern != "" {
ret.Pattern = pattern.Pattern
}
}
ret.Message = request.Message
return
}
// Chat holds the pieces of one chat request after name resolution: the text
// of the chosen context and pattern, the message itself, and the session the
// exchange belongs to.
type Chat struct {
	Context, Pattern, Message string

	// Session is left nil by NewChat when the request names no session.
	Session *db.Session
}