feat: add last changes from fabric-go; fix some Gemini problems

This commit is contained in:
Eugen Eisler 2024-08-17 00:01:55 +02:00
parent 54e5076857
commit 75ee3ac5e4
10 changed files with 159 additions and 196 deletions

View File

@ -64,7 +64,7 @@ func Cli() (message string, err error) {
return
}
if err = db.Patterns.LatestPatterns(parsedToInt); err != nil {
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
return
}
return

View File

@ -2,7 +2,6 @@ package core
import (
"fmt"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
)
@ -17,13 +16,14 @@ type Chatter struct {
}
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil {
return
}
var messages []*common.Message
if messages, err = chatRequest.BuildMessages(); err != nil {
var session *db.Session
if session, err = chatRequest.BuildChatSession(); err != nil {
return
}
@ -34,7 +34,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
if o.Stream {
channel := make(chan string)
go func() {
if streamErr := o.vendor.SendStream(messages, opts, channel); streamErr != nil {
if streamErr := o.vendor.SendStream(session.Messages, opts, channel); streamErr != nil {
channel <- streamErr.Error()
}
}()
@ -44,26 +44,25 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
fmt.Print(response)
}
} else {
if message, err = o.vendor.Send(messages, opts); err != nil {
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
return
}
}
if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append(
&common.Message{Role: "system", Content: message},
&common.Message{Role: "user", Content: chatRequest.Message})
err = chatRequest.Session.Save()
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
err = o.db.Sessions.SaveSession(chatRequest.Session)
}
return
}
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
ret = &Chat{}
if request.ContextName != "" {
var ctx *db.Context
if ctx, err = o.db.Contexts.LoadContext(request.ContextName); err != nil {
if ctx, err = o.db.Contexts.GetContext(request.ContextName); err != nil {
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
return
}
@ -72,7 +71,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
if request.SessionName != "" {
var sess *db.Session
if sess, err = o.db.Sessions.LoadOrCreateSession(request.SessionName); err != nil {
if sess, err = o.db.Sessions.GetOrCreateSession(request.SessionName); err != nil {
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
return
}
@ -81,7 +80,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
if request.PatternName != "" {
var pattern *db.Pattern
if pattern, err = o.db.Patterns.GetByName(request.PatternName); err != nil {
if pattern, err = o.db.Patterns.GetPattern(request.PatternName); err != nil {
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
return
}

View File

@ -3,10 +3,6 @@ package core
import (
"bytes"
"fmt"
"os"
"strconv"
"strings"
"github.com/atotto/clipboard"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
@ -17,12 +13,13 @@ import (
"github.com/danielmiessler/fabric/vendors/ollama"
"github.com/danielmiessler/fabric/vendors/openai"
"github.com/pkg/errors"
"os"
"strconv"
"strings"
)
const (
DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
DefaultPatternsGitRepoFolder = "patterns"
)
const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
const DefaultPatternsGitRepoFolder = "patterns"
func NewFabric(db *db.Db) (ret *Fabric, err error) {
ret = NewFabricBase(db)
@ -38,10 +35,12 @@ func NewFabricForSetup(db *db.Db) (ret *Fabric) {
// NewFabricBase Create a new Fabric from a list of already configured VendorsController
func NewFabricBase(db *db.Db) (ret *Fabric) {
ret = &Fabric{
Db: db,
VendorsController: NewVendors(),
PatternsLoader: NewPatternsLoader(db.Patterns),
VendorsManager: NewVendorsManager(),
Db: db,
VendorsAll: NewVendorsManager(),
PatternsLoader: NewPatternsLoader(db.Patterns),
}
label := "Default"
@ -55,7 +54,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
"Enter the index the name of your default model")
ret.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
gemini.NewClient(), anthropic.NewClient())
return
@ -63,7 +62,8 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
type Fabric struct {
*common.Configurable
*VendorsController
*VendorsManager
VendorsAll *VendorsManager
*PatternsLoader
Db *db.Db
@ -84,7 +84,7 @@ func (o *Fabric) SaveEnvFile() (err error) {
o.Settings.FillEnvFileContent(&envFileContent)
o.PatternsLoader.FillEnvFileContent(&envFileContent)
for _, vendor := range o.Configured {
for _, vendor := range o.Vendors {
vendor.GetSettings().FillEnvFileContent(&envFileContent)
}
@ -126,7 +126,7 @@ func (o *Fabric) SetupDefaultModel() (err error) {
o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
}
// verify
//verify
vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value)
if len(vendorNames) == 0 {
err = errors.Errorf("You need to chose an available default model.")
@ -143,19 +143,19 @@ func (o *Fabric) SetupDefaultModel() (err error) {
}
func (o *Fabric) SetupVendors() (err error) {
o.ResetConfigured()
o.Reset()
for _, vendor := range o.All {
for _, vendor := range o.VendorsAll.Vendors {
fmt.Println()
if vendorErr := vendor.Setup(); vendorErr == nil {
fmt.Printf("[%v] configured\n", vendor.GetName())
o.AddVendorConfigured(vendor)
o.AddVendors(vendor)
} else {
fmt.Printf("[%v] skiped\n", vendor.GetName())
}
}
if !o.HasConfiguredVendors() {
if !o.HasVendors() {
err = errors.New("No vendors configured")
return
}
@ -167,9 +167,9 @@ func (o *Fabric) SetupVendors() (err error) {
// Configure buildClient VendorsController based on the environment variables
func (o *Fabric) configure() (err error) {
for _, vendor := range o.All {
for _, vendor := range o.VendorsAll.Vendors {
if vendorErr := vendor.Configure(); vendorErr == nil {
o.AddVendorConfigured(vendor)
o.AddVendors(vendor)
}
}
err = o.PatternsLoader.Configure()
@ -219,23 +219,27 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
return
}
func (o *Chat) BuildMessages() (ret []*common.Message, err error) {
if o.Session != nil && len(o.Session.Messages) > 0 {
ret = append(ret, o.Session.Messages...)
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
// new messages will be appended to the session and used to send the message
if o.Session != nil {
ret = o.Session
} else {
ret = &db.Session{}
}
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
if systemMessage != "" {
ret = append(ret, &common.Message{Role: "system", Content: systemMessage})
ret.Append(&common.Message{Role: "system", Content: systemMessage})
}
userMessage := strings.TrimSpace(o.Message)
if userMessage != "" {
ret = append(ret, &common.Message{Role: "user", Content: userMessage})
ret.Append(&common.Message{Role: "user", Content: userMessage})
}
if ret == nil {
if ret.IsEmpty() {
ret = nil
err = fmt.Errorf("no session, pattern or user messages provided")
}
return

View File

@ -1,108 +1,97 @@
package core
import (
"context"
"fmt"
"sync"
"github.com/danielmiessler/fabric/common"
"sync"
)
func NewVendors() (ret *VendorsController) {
ret = &VendorsController{
All: map[string]common.Vendor{},
Configured: map[string]common.Vendor{},
func NewVendorsManager() *VendorsManager {
return &VendorsManager{
Vendors: map[string]common.Vendor{},
}
return
}
type VendorsController struct {
All map[string]common.Vendor
Configured map[string]common.Vendor
Models *VendorsModels
type VendorsManager struct {
Vendors map[string]common.Vendor
Models *VendorsModels
}
func (o *VendorsController) AddVendors(vendors ...common.Vendor) {
func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
for _, vendor := range vendors {
o.All[vendor.GetName()] = vendor
o.Vendors[vendor.GetName()] = vendor
}
}
func (o *VendorsController) AddVendorConfigured(vendor common.Vendor) {
o.Configured[vendor.GetName()] = vendor
}
func (o *VendorsController) ResetConfigured() {
o.Configured = map[string]common.Vendor{}
func (o *VendorsManager) Reset() {
o.Vendors = map[string]common.Vendor{}
o.Models = nil
return
}
func (o *VendorsController) GetModels() (ret *VendorsModels) {
func (o *VendorsManager) GetModels() *VendorsModels {
if o.Models == nil {
o.readModels()
}
ret = o.Models
return
return o.Models
}
func (o *VendorsController) HasConfiguredVendors() bool {
return len(o.Configured) > 0
func (o *VendorsManager) HasVendors() bool {
return len(o.Vendors) > 0
}
func (o *VendorsController) readModels() {
func (o *VendorsManager) FindByName(name string) common.Vendor {
return o.Vendors[name]
}
func (o *VendorsManager) readModels() {
o.Models = NewVendorsModels()
var wg sync.WaitGroup
var channels []ChannelName
resultsChan := make(chan modelResult, len(o.Vendors))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errorsChan := make(chan error, 3)
for _, vendor := range o.Configured {
// For each vendor:
// - Create a channel to collect output from the vendor model's list
// - Create a goroutine to query the vendor on its model
cn := ChannelName{channel: make(chan []string, 1), name: vendor.GetName()}
channels = append(channels, cn)
o.createGoroutine(&wg, vendor, cn, errorsChan)
for _, vendor := range o.Vendors {
wg.Add(1)
go o.fetchVendorModels(ctx, &wg, vendor, resultsChan)
}
// Let's wait for completion
wg.Wait() // Wait for all goroutines to finish
close(errorsChan)
for err := range errorsChan {
fmt.Println(err)
o.Models.AddError(err)
}
// And collect output
for _, cn := range channels {
models := <-cn.channel
if models != nil {
o.Models.AddVendorModels(cn.name, models)
}
}
return
}
func (o *VendorsController) FindByName(name string) (ret common.Vendor) {
ret = o.Configured[name]
return
}
// Create a goroutine to list models for the given vendor
func (o *VendorsController) createGoroutine(wg *sync.WaitGroup, vendor common.Vendor, cn ChannelName, errorsChan chan error) {
wg.Add(1)
// Wait for all goroutines to finish
go func() {
defer wg.Done()
models, err := vendor.ListModels()
if err != nil {
errorsChan <- err
cn.channel <- nil
} else {
cn.channel <- models
}
wg.Wait()
close(resultsChan)
}()
// Collect results
for result := range resultsChan {
if result.err != nil {
fmt.Println(result.vendorName, result.err)
o.Models.AddError(result.err)
cancel() // Cancel remaining goroutines if needed
} else {
o.Models.AddVendorModels(result.vendorName, result.models)
}
}
}
func (o *VendorsManager) fetchVendorModels(
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
defer wg.Done()
models, err := vendor.ListModels()
select {
case <-ctx.Done():
// Context canceled, don't send the result
return
case resultsChan <- modelResult{vendorName: vendor.GetName(), models: models, err: err}:
// Result sent
}
}
type modelResult struct {
vendorName string
models []string
err error
}

View File

@ -1,19 +1,13 @@
package db
import (
"os"
)
type Contexts struct {
*Storage
}
// LoadContext Load a context from file
func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
path := o.BuildFilePathByName(name)
// GetContext Load a context from file
func (o *Contexts) GetContext(name string) (ret *Context, err error) {
var content []byte
if content, err = os.ReadFile(path); err != nil {
if content, err = o.Load(name); err != nil {
return
}
@ -24,12 +18,4 @@ func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
type Context struct {
Name string
Content string
contexts *Contexts
}
// Save the session on disk
func (o *Context) Save() (err error) {
err = o.contexts.Save(o.Name, []byte(o.Content))
return err
}

View File

@ -19,8 +19,12 @@ func NewDb(dir string) (db *Db) {
SystemPatternFile: "system.md",
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
}
db.Sessions = &Sessions{&Storage{Label: "Sessions", Dir: db.FilePath("sessions")}}
db.Contexts = &Contexts{&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
db.Sessions = &Sessions{
&Storage{Label: "Sessions", Dir: db.FilePath("sessions"), FileExtension: ".json"}}
db.Contexts = &Contexts{
&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
return
}

View File

@ -13,8 +13,8 @@ type Patterns struct {
UniquePatternsFilePath string
}
// GetByName finds a pattern by name and returns the pattern as an entry or an error
func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
// GetPattern finds a pattern by name and returns the pattern as an entry or an error
func (o *Patterns) GetPattern(name string) (ret *Pattern, err error) {
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
var pattern []byte
@ -28,7 +28,7 @@ func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
return
}
func (o *Patterns) LatestPatterns(latestNumber int) (err error) {
func (o *Patterns) PrintLatestPatterns(latestNumber int) (err error) {
var contents []byte
if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil {
err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err)

View File

@ -1,11 +1,7 @@
package db
import (
"encoding/json"
"errors"
"fmt"
"os"
"github.com/danielmiessler/fabric/common"
)
@ -13,56 +9,30 @@ type Sessions struct {
*Storage
}
func (o *Sessions) LoadOrCreateSession(name string) (ret *Session, err error) {
if name == "" {
return &Session{}, nil
}
func (o *Sessions) GetOrCreateSession(name string) (session *Session, err error) {
session = &Session{Name: name}
path := o.BuildFilePath(name)
if _, statErr := os.Stat(path); errors.Is(statErr, os.ErrNotExist) {
fmt.Printf("Creating new session: %s\n", name)
ret = &Session{Name: name, sessions: o}
if o.Exists(name) {
err = o.LoadAsJson(name, &session.Messages)
} else {
ret, err = o.loadSession(name)
fmt.Printf("Creating new session: %s\n", name)
}
return
}
// LoadSession Load a session from file
func (o *Sessions) LoadSession(name string) (ret *Session, err error) {
if name == "" {
return &Session{}, nil
}
ret, err = o.loadSession(name)
return
}
func (o *Sessions) loadSession(name string) (ret *Session, err error) {
ret = &Session{Name: name, sessions: o}
if err = o.LoadAsJson(name, &ret.Messages); err != nil {
return
}
return
func (o *Sessions) SaveSession(session *Session) (err error) {
return o.SaveAsJson(session.Name, session.Messages)
}
type Session struct {
Name string
Messages []*common.Message
}
sessions *Sessions
func (o *Session) IsEmpty() bool {
return len(o.Messages) == 0
}
func (o *Session) Append(messages ...*common.Message) {
o.Messages = append(o.Messages, messages...)
}
// Save the session on disk
func (o *Session) Save() (err error) {
var jsonBytes []byte
if jsonBytes, err = json.Marshal(o.Messages); err == nil {
err = o.sessions.Save(o.Name, jsonBytes)
} else {
err = fmt.Errorf("could not marshal session %o: %o", o.Name, err)
}
return
}

View File

@ -6,13 +6,14 @@ import (
"github.com/samber/lo"
"os"
"path/filepath"
"strings"
)
type Storage struct {
Label string
Dir string
ItemIsDir bool
ItemExtension string
FileExtension string
}
func (o *Storage) Configure() (err error) {
@ -38,12 +39,21 @@ func (o *Storage) GetNames() (ret []string, err error) {
return
})
} else {
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.ItemExtension; ok {
ret = item.Name()
}
return
})
if o.FileExtension == "" {
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
if ok = !item.IsDir(); ok {
ret = item.Name()
}
return
})
} else {
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.FileExtension; ok {
ret = strings.TrimSuffix(item.Name(), o.FileExtension)
}
return
})
}
}
return
}
@ -77,7 +87,7 @@ func (o *Storage) BuildFilePath(fileName string) (ret string) {
}
func (o *Storage) buildFileName(name string) string {
return fmt.Sprintf("%s%v", name, o.ItemExtension)
return fmt.Sprintf("%s%v", name, o.FileExtension)
}
func (o *Storage) Delete(name string) (err error) {

View File

@ -27,8 +27,6 @@ func NewClient() (ret *Client) {
type Client struct {
*common.Configurable
ApiKey *common.SetupQuestion
client *genai.Client
}
func (ge *Client) ListModels() (ret []string, err error) {
@ -43,6 +41,9 @@ func (ge *Client) ListModels() (ret []string, err error) {
for {
var resp *genai.ModelInfo
if resp, err = iter.Next(); err != nil {
if errors.Is(err, iterator.Done) {
err = nil
}
break
}
ret = append(ret, resp.Name)
@ -60,7 +61,7 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
}
defer client.Close()
model := ge.client.GenerativeModel(opts.Model)
model := client.GenerativeModel(opts.Model)
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
@ -128,17 +129,17 @@ func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret stri
// Current implementation does not support session
// We need to retrieve the System instruction and User instruction
// Considering how we've built msgs, it's the last 2 messages
// FIXME: I know it's not clean, but will make it for now
// FIXME: Session support will need to be added
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
sys := msgs[len(msgs)-2]
usr := msgs[len(msgs)-1]
ret = &genai.Content{
Parts: []genai.Part{
genai.Part(genai.Text(sys.Content)),
},
if len(msgs) >= 2 {
ret = &genai.Content{
Parts: []genai.Part{
genai.Part(genai.Text(msgs[0].Content)),
},
}
userText = msgs[1].Content
} else {
userText = msgs[0].Content
}
userText = usr.Content
return
}