mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
feat: add last changes from fabric-go; fix some Gemini problems
This commit is contained in:
parent
54e5076857
commit
75ee3ac5e4
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = db.Patterns.LatestPatterns(parsedToInt); err != nil {
|
||||
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
|
@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
)
|
||||
@ -17,13 +16,14 @@ type Chatter struct {
|
||||
}
|
||||
|
||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||
|
||||
var chatRequest *Chat
|
||||
if chatRequest, err = o.NewChat(request); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var messages []*common.Message
|
||||
if messages, err = chatRequest.BuildMessages(); err != nil {
|
||||
var session *db.Session
|
||||
if session, err = chatRequest.BuildChatSession(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -34,7 +34,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
if o.Stream {
|
||||
channel := make(chan string)
|
||||
go func() {
|
||||
if streamErr := o.vendor.SendStream(messages, opts, channel); streamErr != nil {
|
||||
if streamErr := o.vendor.SendStream(session.Messages, opts, channel); streamErr != nil {
|
||||
channel <- streamErr.Error()
|
||||
}
|
||||
}()
|
||||
@ -44,26 +44,25 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
fmt.Print(response)
|
||||
}
|
||||
} else {
|
||||
if message, err = o.vendor.Send(messages, opts); err != nil {
|
||||
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if chatRequest.Session != nil && message != "" {
|
||||
chatRequest.Session.Append(
|
||||
&common.Message{Role: "system", Content: message},
|
||||
&common.Message{Role: "user", Content: chatRequest.Message})
|
||||
err = chatRequest.Session.Save()
|
||||
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
|
||||
err = o.db.Sessions.SaveSession(chatRequest.Session)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||
|
||||
ret = &Chat{}
|
||||
|
||||
if request.ContextName != "" {
|
||||
var ctx *db.Context
|
||||
if ctx, err = o.db.Contexts.LoadContext(request.ContextName); err != nil {
|
||||
if ctx, err = o.db.Contexts.GetContext(request.ContextName); err != nil {
|
||||
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
|
||||
return
|
||||
}
|
||||
@ -72,7 +71,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||
|
||||
if request.SessionName != "" {
|
||||
var sess *db.Session
|
||||
if sess, err = o.db.Sessions.LoadOrCreateSession(request.SessionName); err != nil {
|
||||
if sess, err = o.db.Sessions.GetOrCreateSession(request.SessionName); err != nil {
|
||||
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
|
||||
return
|
||||
}
|
||||
@ -81,7 +80,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||
|
||||
if request.PatternName != "" {
|
||||
var pattern *db.Pattern
|
||||
if pattern, err = o.db.Patterns.GetByName(request.PatternName); err != nil {
|
||||
if pattern, err = o.db.Patterns.GetPattern(request.PatternName); err != nil {
|
||||
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
|
||||
return
|
||||
}
|
||||
|
@ -3,10 +3,6 @@ package core
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/atotto/clipboard"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
@ -17,12 +13,13 @@ import (
|
||||
"github.com/danielmiessler/fabric/vendors/ollama"
|
||||
"github.com/danielmiessler/fabric/vendors/openai"
|
||||
"github.com/pkg/errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
|
||||
DefaultPatternsGitRepoFolder = "patterns"
|
||||
)
|
||||
const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
|
||||
const DefaultPatternsGitRepoFolder = "patterns"
|
||||
|
||||
func NewFabric(db *db.Db) (ret *Fabric, err error) {
|
||||
ret = NewFabricBase(db)
|
||||
@ -38,10 +35,12 @@ func NewFabricForSetup(db *db.Db) (ret *Fabric) {
|
||||
|
||||
// NewFabricBase Create a new Fabric from a list of already configured VendorsController
|
||||
func NewFabricBase(db *db.Db) (ret *Fabric) {
|
||||
|
||||
ret = &Fabric{
|
||||
Db: db,
|
||||
VendorsController: NewVendors(),
|
||||
PatternsLoader: NewPatternsLoader(db.Patterns),
|
||||
VendorsManager: NewVendorsManager(),
|
||||
Db: db,
|
||||
VendorsAll: NewVendorsManager(),
|
||||
PatternsLoader: NewPatternsLoader(db.Patterns),
|
||||
}
|
||||
|
||||
label := "Default"
|
||||
@ -55,7 +54,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
|
||||
ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
|
||||
"Enter the index the name of your default model")
|
||||
|
||||
ret.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
|
||||
ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
|
||||
gemini.NewClient(), anthropic.NewClient())
|
||||
|
||||
return
|
||||
@ -63,7 +62,8 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
|
||||
|
||||
type Fabric struct {
|
||||
*common.Configurable
|
||||
*VendorsController
|
||||
*VendorsManager
|
||||
VendorsAll *VendorsManager
|
||||
*PatternsLoader
|
||||
|
||||
Db *db.Db
|
||||
@ -84,7 +84,7 @@ func (o *Fabric) SaveEnvFile() (err error) {
|
||||
o.Settings.FillEnvFileContent(&envFileContent)
|
||||
o.PatternsLoader.FillEnvFileContent(&envFileContent)
|
||||
|
||||
for _, vendor := range o.Configured {
|
||||
for _, vendor := range o.Vendors {
|
||||
vendor.GetSettings().FillEnvFileContent(&envFileContent)
|
||||
}
|
||||
|
||||
@ -126,7 +126,7 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
||||
o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
|
||||
}
|
||||
|
||||
// verify
|
||||
//verify
|
||||
vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value)
|
||||
if len(vendorNames) == 0 {
|
||||
err = errors.Errorf("You need to chose an available default model.")
|
||||
@ -143,19 +143,19 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
||||
}
|
||||
|
||||
func (o *Fabric) SetupVendors() (err error) {
|
||||
o.ResetConfigured()
|
||||
o.Reset()
|
||||
|
||||
for _, vendor := range o.All {
|
||||
for _, vendor := range o.VendorsAll.Vendors {
|
||||
fmt.Println()
|
||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||
o.AddVendorConfigured(vendor)
|
||||
o.AddVendors(vendor)
|
||||
} else {
|
||||
fmt.Printf("[%v] skiped\n", vendor.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
if !o.HasConfiguredVendors() {
|
||||
if !o.HasVendors() {
|
||||
err = errors.New("No vendors configured")
|
||||
return
|
||||
}
|
||||
@ -167,9 +167,9 @@ func (o *Fabric) SetupVendors() (err error) {
|
||||
|
||||
// Configure buildClient VendorsController based on the environment variables
|
||||
func (o *Fabric) configure() (err error) {
|
||||
for _, vendor := range o.All {
|
||||
for _, vendor := range o.VendorsAll.Vendors {
|
||||
if vendorErr := vendor.Configure(); vendorErr == nil {
|
||||
o.AddVendorConfigured(vendor)
|
||||
o.AddVendors(vendor)
|
||||
}
|
||||
}
|
||||
err = o.PatternsLoader.Configure()
|
||||
@ -219,23 +219,27 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Chat) BuildMessages() (ret []*common.Message, err error) {
|
||||
if o.Session != nil && len(o.Session.Messages) > 0 {
|
||||
ret = append(ret, o.Session.Messages...)
|
||||
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
||||
// new messages will be appended to the session and used to send the message
|
||||
if o.Session != nil {
|
||||
ret = o.Session
|
||||
} else {
|
||||
ret = &db.Session{}
|
||||
}
|
||||
|
||||
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
||||
|
||||
if systemMessage != "" {
|
||||
ret = append(ret, &common.Message{Role: "system", Content: systemMessage})
|
||||
ret.Append(&common.Message{Role: "system", Content: systemMessage})
|
||||
}
|
||||
|
||||
userMessage := strings.TrimSpace(o.Message)
|
||||
if userMessage != "" {
|
||||
ret = append(ret, &common.Message{Role: "user", Content: userMessage})
|
||||
ret.Append(&common.Message{Role: "user", Content: userMessage})
|
||||
}
|
||||
|
||||
if ret == nil {
|
||||
if ret.IsEmpty() {
|
||||
ret = nil
|
||||
err = fmt.Errorf("no session, pattern or user messages provided")
|
||||
}
|
||||
return
|
||||
|
135
core/vendors.go
135
core/vendors.go
@ -1,108 +1,97 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func NewVendors() (ret *VendorsController) {
|
||||
ret = &VendorsController{
|
||||
All: map[string]common.Vendor{},
|
||||
Configured: map[string]common.Vendor{},
|
||||
func NewVendorsManager() *VendorsManager {
|
||||
return &VendorsManager{
|
||||
Vendors: map[string]common.Vendor{},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type VendorsController struct {
|
||||
All map[string]common.Vendor
|
||||
Configured map[string]common.Vendor
|
||||
|
||||
Models *VendorsModels
|
||||
type VendorsManager struct {
|
||||
Vendors map[string]common.Vendor
|
||||
Models *VendorsModels
|
||||
}
|
||||
|
||||
func (o *VendorsController) AddVendors(vendors ...common.Vendor) {
|
||||
func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
|
||||
for _, vendor := range vendors {
|
||||
o.All[vendor.GetName()] = vendor
|
||||
o.Vendors[vendor.GetName()] = vendor
|
||||
}
|
||||
}
|
||||
|
||||
func (o *VendorsController) AddVendorConfigured(vendor common.Vendor) {
|
||||
o.Configured[vendor.GetName()] = vendor
|
||||
}
|
||||
|
||||
func (o *VendorsController) ResetConfigured() {
|
||||
o.Configured = map[string]common.Vendor{}
|
||||
func (o *VendorsManager) Reset() {
|
||||
o.Vendors = map[string]common.Vendor{}
|
||||
o.Models = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (o *VendorsController) GetModels() (ret *VendorsModels) {
|
||||
func (o *VendorsManager) GetModels() *VendorsModels {
|
||||
if o.Models == nil {
|
||||
o.readModels()
|
||||
}
|
||||
ret = o.Models
|
||||
return
|
||||
return o.Models
|
||||
}
|
||||
|
||||
func (o *VendorsController) HasConfiguredVendors() bool {
|
||||
return len(o.Configured) > 0
|
||||
func (o *VendorsManager) HasVendors() bool {
|
||||
return len(o.Vendors) > 0
|
||||
}
|
||||
|
||||
func (o *VendorsController) readModels() {
|
||||
func (o *VendorsManager) FindByName(name string) common.Vendor {
|
||||
return o.Vendors[name]
|
||||
}
|
||||
|
||||
func (o *VendorsManager) readModels() {
|
||||
o.Models = NewVendorsModels()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var channels []ChannelName
|
||||
resultsChan := make(chan modelResult, len(o.Vendors))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errorsChan := make(chan error, 3)
|
||||
|
||||
for _, vendor := range o.Configured {
|
||||
// For each vendor:
|
||||
// - Create a channel to collect output from the vendor model's list
|
||||
// - Create a goroutine to query the vendor on its model
|
||||
cn := ChannelName{channel: make(chan []string, 1), name: vendor.GetName()}
|
||||
channels = append(channels, cn)
|
||||
o.createGoroutine(&wg, vendor, cn, errorsChan)
|
||||
for _, vendor := range o.Vendors {
|
||||
wg.Add(1)
|
||||
go o.fetchVendorModels(ctx, &wg, vendor, resultsChan)
|
||||
}
|
||||
|
||||
// Let's wait for completion
|
||||
wg.Wait() // Wait for all goroutines to finish
|
||||
close(errorsChan)
|
||||
|
||||
for err := range errorsChan {
|
||||
fmt.Println(err)
|
||||
o.Models.AddError(err)
|
||||
}
|
||||
|
||||
// And collect output
|
||||
for _, cn := range channels {
|
||||
models := <-cn.channel
|
||||
if models != nil {
|
||||
o.Models.AddVendorModels(cn.name, models)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (o *VendorsController) FindByName(name string) (ret common.Vendor) {
|
||||
ret = o.Configured[name]
|
||||
return
|
||||
}
|
||||
|
||||
// Create a goroutine to list models for the given vendor
|
||||
func (o *VendorsController) createGoroutine(wg *sync.WaitGroup, vendor common.Vendor, cn ChannelName, errorsChan chan error) {
|
||||
wg.Add(1)
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
models, err := vendor.ListModels()
|
||||
if err != nil {
|
||||
errorsChan <- err
|
||||
cn.channel <- nil
|
||||
} else {
|
||||
cn.channel <- models
|
||||
}
|
||||
wg.Wait()
|
||||
close(resultsChan)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
for result := range resultsChan {
|
||||
if result.err != nil {
|
||||
fmt.Println(result.vendorName, result.err)
|
||||
o.Models.AddError(result.err)
|
||||
cancel() // Cancel remaining goroutines if needed
|
||||
} else {
|
||||
o.Models.AddVendorModels(result.vendorName, result.models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *VendorsManager) fetchVendorModels(
|
||||
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
models, err := vendor.ListModels()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context canceled, don't send the result
|
||||
return
|
||||
case resultsChan <- modelResult{vendorName: vendor.GetName(), models: models, err: err}:
|
||||
// Result sent
|
||||
}
|
||||
}
|
||||
|
||||
type modelResult struct {
|
||||
vendorName string
|
||||
models []string
|
||||
err error
|
||||
}
|
||||
|
@ -1,19 +1,13 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
type Contexts struct {
|
||||
*Storage
|
||||
}
|
||||
|
||||
// LoadContext Load a context from file
|
||||
func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
|
||||
path := o.BuildFilePathByName(name)
|
||||
|
||||
// GetContext Load a context from file
|
||||
func (o *Contexts) GetContext(name string) (ret *Context, err error) {
|
||||
var content []byte
|
||||
if content, err = os.ReadFile(path); err != nil {
|
||||
if content, err = o.Load(name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -24,12 +18,4 @@ func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
|
||||
type Context struct {
|
||||
Name string
|
||||
Content string
|
||||
|
||||
contexts *Contexts
|
||||
}
|
||||
|
||||
// Save the session on disk
|
||||
func (o *Context) Save() (err error) {
|
||||
err = o.contexts.Save(o.Name, []byte(o.Content))
|
||||
return err
|
||||
}
|
||||
|
8
db/db.go
8
db/db.go
@ -19,8 +19,12 @@ func NewDb(dir string) (db *Db) {
|
||||
SystemPatternFile: "system.md",
|
||||
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
||||
}
|
||||
db.Sessions = &Sessions{&Storage{Label: "Sessions", Dir: db.FilePath("sessions")}}
|
||||
db.Contexts = &Contexts{&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
|
||||
|
||||
db.Sessions = &Sessions{
|
||||
&Storage{Label: "Sessions", Dir: db.FilePath("sessions"), FileExtension: ".json"}}
|
||||
|
||||
db.Contexts = &Contexts{
|
||||
&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ type Patterns struct {
|
||||
UniquePatternsFilePath string
|
||||
}
|
||||
|
||||
// GetByName finds a pattern by name and returns the pattern as an entry or an error
|
||||
func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
|
||||
// GetPattern finds a pattern by name and returns the pattern as an entry or an error
|
||||
func (o *Patterns) GetPattern(name string) (ret *Pattern, err error) {
|
||||
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
||||
|
||||
var pattern []byte
|
||||
@ -28,7 +28,7 @@ func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Patterns) LatestPatterns(latestNumber int) (err error) {
|
||||
func (o *Patterns) PrintLatestPatterns(latestNumber int) (err error) {
|
||||
var contents []byte
|
||||
if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil {
|
||||
err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err)
|
||||
|
@ -1,11 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
@ -13,56 +9,30 @@ type Sessions struct {
|
||||
*Storage
|
||||
}
|
||||
|
||||
func (o *Sessions) LoadOrCreateSession(name string) (ret *Session, err error) {
|
||||
if name == "" {
|
||||
return &Session{}, nil
|
||||
}
|
||||
func (o *Sessions) GetOrCreateSession(name string) (session *Session, err error) {
|
||||
session = &Session{Name: name}
|
||||
|
||||
path := o.BuildFilePath(name)
|
||||
if _, statErr := os.Stat(path); errors.Is(statErr, os.ErrNotExist) {
|
||||
fmt.Printf("Creating new session: %s\n", name)
|
||||
ret = &Session{Name: name, sessions: o}
|
||||
if o.Exists(name) {
|
||||
err = o.LoadAsJson(name, &session.Messages)
|
||||
} else {
|
||||
ret, err = o.loadSession(name)
|
||||
fmt.Printf("Creating new session: %s\n", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// LoadSession Load a session from file
|
||||
func (o *Sessions) LoadSession(name string) (ret *Session, err error) {
|
||||
if name == "" {
|
||||
return &Session{}, nil
|
||||
}
|
||||
ret, err = o.loadSession(name)
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Sessions) loadSession(name string) (ret *Session, err error) {
|
||||
ret = &Session{Name: name, sessions: o}
|
||||
if err = o.LoadAsJson(name, &ret.Messages); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
func (o *Sessions) SaveSession(session *Session) (err error) {
|
||||
return o.SaveAsJson(session.Name, session.Messages)
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Name string
|
||||
Messages []*common.Message
|
||||
}
|
||||
|
||||
sessions *Sessions
|
||||
func (o *Session) IsEmpty() bool {
|
||||
return len(o.Messages) == 0
|
||||
}
|
||||
|
||||
func (o *Session) Append(messages ...*common.Message) {
|
||||
o.Messages = append(o.Messages, messages...)
|
||||
}
|
||||
|
||||
// Save the session on disk
|
||||
func (o *Session) Save() (err error) {
|
||||
var jsonBytes []byte
|
||||
if jsonBytes, err = json.Marshal(o.Messages); err == nil {
|
||||
err = o.sessions.Save(o.Name, jsonBytes)
|
||||
} else {
|
||||
err = fmt.Errorf("could not marshal session %o: %o", o.Name, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -6,13 +6,14 @@ import (
|
||||
"github.com/samber/lo"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Label string
|
||||
Dir string
|
||||
ItemIsDir bool
|
||||
ItemExtension string
|
||||
FileExtension string
|
||||
}
|
||||
|
||||
func (o *Storage) Configure() (err error) {
|
||||
@ -38,12 +39,21 @@ func (o *Storage) GetNames() (ret []string, err error) {
|
||||
return
|
||||
})
|
||||
} else {
|
||||
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
||||
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.ItemExtension; ok {
|
||||
ret = item.Name()
|
||||
}
|
||||
return
|
||||
})
|
||||
if o.FileExtension == "" {
|
||||
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
||||
if ok = !item.IsDir(); ok {
|
||||
ret = item.Name()
|
||||
}
|
||||
return
|
||||
})
|
||||
} else {
|
||||
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
||||
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.FileExtension; ok {
|
||||
ret = strings.TrimSuffix(item.Name(), o.FileExtension)
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -77,7 +87,7 @@ func (o *Storage) BuildFilePath(fileName string) (ret string) {
|
||||
}
|
||||
|
||||
func (o *Storage) buildFileName(name string) string {
|
||||
return fmt.Sprintf("%s%v", name, o.ItemExtension)
|
||||
return fmt.Sprintf("%s%v", name, o.FileExtension)
|
||||
}
|
||||
|
||||
func (o *Storage) Delete(name string) (err error) {
|
||||
|
27
vendors/gemini/gemini.go
vendored
27
vendors/gemini/gemini.go
vendored
@ -27,8 +27,6 @@ func NewClient() (ret *Client) {
|
||||
type Client struct {
|
||||
*common.Configurable
|
||||
ApiKey *common.SetupQuestion
|
||||
|
||||
client *genai.Client
|
||||
}
|
||||
|
||||
func (ge *Client) ListModels() (ret []string, err error) {
|
||||
@ -43,6 +41,9 @@ func (ge *Client) ListModels() (ret []string, err error) {
|
||||
for {
|
||||
var resp *genai.ModelInfo
|
||||
if resp, err = iter.Next(); err != nil {
|
||||
if errors.Is(err, iterator.Done) {
|
||||
err = nil
|
||||
}
|
||||
break
|
||||
}
|
||||
ret = append(ret, resp.Name)
|
||||
@ -60,7 +61,7 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
model := ge.client.GenerativeModel(opts.Model)
|
||||
model := client.GenerativeModel(opts.Model)
|
||||
model.SetTemperature(float32(opts.Temperature))
|
||||
model.SetTopP(float32(opts.TopP))
|
||||
model.SystemInstruction = systemInstruction
|
||||
@ -128,17 +129,17 @@ func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret stri
|
||||
// Current implementation does not support session
|
||||
// We need to retrieve the System instruction and User instruction
|
||||
// Considering how we've built msgs, it's the last 2 messages
|
||||
// FIXME: I know it's not clean, but will make it for now
|
||||
// FIXME: Session support will need to be added
|
||||
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
|
||||
sys := msgs[len(msgs)-2]
|
||||
usr := msgs[len(msgs)-1]
|
||||
|
||||
ret = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Part(genai.Text(sys.Content)),
|
||||
},
|
||||
if len(msgs) >= 2 {
|
||||
ret = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Part(genai.Text(msgs[0].Content)),
|
||||
},
|
||||
}
|
||||
userText = msgs[1].Content
|
||||
} else {
|
||||
userText = msgs[0].Content
|
||||
}
|
||||
userText = usr.Content
|
||||
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user