You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fabric/core/fabric.go

243 lines
5.6 KiB
Go

1 month ago
package core
import (
"bytes"
"fmt"
"os"
"strconv"
"strings"
"github.com/atotto/clipboard"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors/anthropic"
"github.com/danielmiessler/fabric/vendors/azure"
"github.com/danielmiessler/fabric/vendors/gemini"
"github.com/danielmiessler/fabric/vendors/grocq"
"github.com/danielmiessler/fabric/vendors/ollama"
"github.com/danielmiessler/fabric/vendors/openai"
"github.com/pkg/errors"
)
// DefaultPatternsGitRepoUrl is the default Git repository URL for patterns.
// It is not referenced in this file; presumably it is consumed by the
// patterns loader.
const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"

// DefaultPatternsGitRepoFolder is the default folder name for patterns
// (presumably the folder inside DefaultPatternsGitRepoUrl that holds them).
const DefaultPatternsGitRepoFolder = "patterns"
// NewFabric creates a Fabric via NewFabricBase and then runs Configure on it.
// The Fabric is returned even when Configure fails, together with that error.
func NewFabric(db *db.Db) (ret *Fabric, err error) {
	ret = NewFabricBase(db)
	if err = ret.Configure(); err != nil {
		return ret, err
	}
	return ret, nil
}
// NewFabricForSetup creates a Fabric via NewFabricBase for the setup flow.
// Unlike NewFabric it never fails: the Configure error is deliberately
// discarded so that a usable Fabric is returned even when the current
// configuration is incomplete.
func NewFabricForSetup(db *db.Db) (ret *Fabric) {
	f := NewFabricBase(db)
	// Best-effort configuration; see the doc comment for why the error is ignored.
	_ = f.Configure()
	return f
}
// NewFabricBase creates a Fabric backed by db with a fresh VendorsController,
// a PatternsLoader built from db.Patterns, and the vendor clients from the
// openai, azure, ollama, grocq, gemini and anthropic packages added via
// AddVendors. It does not configure anything; NewFabric and
// NewFabricForSetup call Configure on the result.
//
// It also declares the "Vendor" setting and the "Model" setup question that
// hold the default vendor/model.
func NewFabricBase(db *db.Db) (ret *Fabric) {
	ret = &Fabric{
		Db:                db,
		VendorsController: NewVendors(),
		PatternsLoader:    NewPatternsLoader(db.Patterns),
	}

	label := "Default"
	ret.Configurable = &common.Configurable{
		Label:         label,
		EnvNamePrefix: common.BuildEnvVariablePrefix(label),
		// configure is bound to this Fabric so Configure() can register vendors.
		ConfigureCustom: ret.configure,
	}

	ret.DefaultVendor = ret.AddSetting("Vendor", true)
	// SetupDefaultModel accepts either a numeric index or a model name here.
	ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
		"Enter the index or the name of your default model")

	ret.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
		gemini.NewClient(), anthropic.NewClient())
	return
}
// Fabric ties together its own configuration (embedded Configurable), the
// known and configured vendors (embedded VendorsController) and pattern
// loading (embedded PatternsLoader).
type Fabric struct {
*common.Configurable
*VendorsController
*PatternsLoader
// Db is the storage backend; SaveEnvFile persists settings through it and
// GetChatter hands it to each Chatter.
Db *db.Db
// DefaultVendor holds the vendor name used by GetChatter when no model is given.
DefaultVendor *common.Setting
// DefaultModel holds the model used by GetChatter when no model is given;
// during SetupDefaultModel it may be answered with an index or a model name.
DefaultModel *common.SetupQuestion
}
// ChannelName pairs a channel of string slices with a name.
// NOTE(review): ChannelName is not referenced anywhere in this file; confirm
// it is still used elsewhere in the package.
type ChannelName struct {
// channel carries []string values; its producer/consumer is not visible here.
channel chan []string
// name presumably labels the channel — TODO confirm with its users.
name string
}
// SaveEnvFile builds one env-file body from the Fabric's own settings, the
// patterns loader settings and the settings of every configured vendor (in
// that order), and persists it with Db.SaveEnv.
func (o *Fabric) SaveEnvFile() (err error) {
	var content bytes.Buffer

	o.Settings.FillEnvFileContent(&content)
	o.PatternsLoader.FillEnvFileContent(&content)
	for _, configured := range o.Configured {
		configured.GetSettings().FillEnvFileContent(&content)
	}

	return o.Db.SaveEnv(content.String())
}
// Setup runs the interactive setup steps in order: vendors, default model,
// patterns loader, and finally saving the env file. It stops at, and
// returns, the first error.
func (o *Fabric) Setup() (err error) {
	steps := []func() error{
		o.SetupVendors,
		o.SetupDefaultModel,
		// Wrapped so PatternsLoader is read when the step runs, not up front.
		func() error { return o.PatternsLoader.Setup() },
		o.SaveEnvFile,
	}
	for _, step := range steps {
		if err = step(); err != nil {
			return err
		}
	}
	return nil
}
// SetupDefaultModel prints the models reported by GetModels, asks the setup
// questions (via Ask), and resolves the answer stored in DefaultModel:
//   - a numeric answer is treated as an index and resolved to a vendor and
//     model with GetVendorAndModelByModelIndex (presumably matching the
//     printed list);
//   - any other answer is treated as a model name and the first vendor
//     offering it is chosen.
//
// It returns an error if the resolved model is not offered by any vendor;
// otherwise it prints the chosen vendor/model and saves the env file.
func (o *Fabric) SetupDefaultModel() (err error) {
	vendorsModels := o.GetModels()
	vendorsModels.Print()

	if err = o.Ask(o.Label); err != nil {
		return err
	}

	if index, parseErr := strconv.Atoi(o.DefaultModel.Value); parseErr == nil {
		o.DefaultVendor.Value, o.DefaultModel.Value = vendorsModels.GetVendorAndModelByModelIndex(index)
	} else {
		o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
	}

	// Verify that the chosen model is actually offered by some vendor.
	if vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value); len(vendorNames) == 0 {
		return errors.New("you need to choose an available default model")
	}

	fmt.Println()
	o.DefaultVendor.Print()
	o.DefaultModel.Print()

	return o.SaveEnvFile()
}
// SetupVendors clears the configured-vendor list and runs Setup on every
// known vendor, marking each one that succeeds as configured. A vendor whose
// Setup returns an error is reported as skipped and does not stop the
// remaining vendors from being set up.
//
// It returns an error if no vendor ended up configured; otherwise it saves
// the env file and returns that result.
func (o *Fabric) SetupVendors() (err error) {
	o.ResetConfigured()

	for _, vendor := range o.All {
		fmt.Println()
		if vendorErr := vendor.Setup(); vendorErr != nil {
			fmt.Printf("[%v] skipped\n", vendor.GetName())
			continue
		}
		fmt.Printf("[%v] configured\n", vendor.GetName())
		o.AddVendorConfigured(vendor)
	}

	if !o.HasConfiguredVendors() {
		return errors.New("no vendors configured")
	}
	return o.SaveEnvFile()
}
// configure is the ConfigureCustom hook installed by NewFabricBase. It calls
// Configure on every known vendor and marks each one that succeeds as
// configured; vendors that fail are skipped silently (best-effort). It then
// configures the patterns loader and returns that error, if any.
func (o *Fabric) configure() (err error) {
	for _, v := range o.All {
		if v.Configure() != nil {
			continue
		}
		o.AddVendorConfigured(v)
	}
	return o.PatternsLoader.Configure()
}
// GetChatter returns a Chatter bound to o.Db, with the given stream flag.
// When model is empty, the default vendor and default model are used;
// otherwise the first vendor offering model (according to GetModels) is
// looked up. If no vendor is found, an error describing the inputs is
// returned along with the (vendor-less) Chatter.
func (o *Fabric) GetChatter(model string, stream bool) (ret *Chatter, err error) {
	ret = &Chatter{db: o.Db, Stream: stream}

	if model != "" {
		vendorName := o.GetModels().FindVendorsByModelFirst(model)
		ret.vendor = o.FindByName(vendorName)
		ret.model = model
	} else {
		ret.vendor = o.FindByName(o.DefaultVendor.Value)
		ret.model = o.DefaultModel.Value
	}

	if ret.vendor == nil {
		err = fmt.Errorf(
			"could not find vendor.\n Model = %s\n DefaultModel = %s\n DefaultVendor = %s",
			model, o.DefaultModel.Value, o.DefaultVendor.Value)
	}
	return ret, err
}
// CopyToClipboard writes message to the system clipboard via
// clipboard.WriteAll. On failure the underlying error is wrapped (with %w),
// so callers can still inspect it with errors.Is / errors.As.
func (o *Fabric) CopyToClipboard(message string) (err error) {
	if err = clipboard.WriteAll(message); err != nil {
		err = fmt.Errorf("could not copy to clipboard: %w", err)
	}
	return
}
// CreateOutputFile creates (or truncates) fileName and writes message to it.
// Errors from creating, writing, or closing the file are returned wrapped
// (with %w), so callers can match the cause with errors.Is / errors.As.
func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
	var file *os.File
	if file, err = os.Create(fileName); err != nil {
		return fmt.Errorf("error creating file: %w", err)
	}
	// A failed Close on a written file can mean the data never reached
	// storage, so report it unless an earlier error is already being returned.
	defer func() {
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", closeErr)
		}
	}()

	if _, err = file.WriteString(message); err != nil {
		err = fmt.Errorf("error writing to file: %w", err)
	}
	return
}
// BuildMessages assembles the messages for this chat, in order:
//
//  1. every message already stored in o.Session (if there is a session),
//  2. a "system" message made of the trimmed Context followed by the trimmed
//     Pattern, separated by a newline when both are non-empty,
//  3. a "user" message holding the trimmed Message.
//
// Parts that are empty after trimming are omitted. An error is returned when
// the resulting list is empty.
func (o *Chat) BuildMessages() (ret []*common.Message, err error) {
	if o.Session != nil {
		// Appending an empty slice is a no-op, so no length check is needed.
		ret = append(ret, o.Session.Messages...)
	}

	contextText := strings.TrimSpace(o.Context)
	patternText := strings.TrimSpace(o.Pattern)
	systemMessage := contextText + patternText
	if contextText != "" && patternText != "" {
		// Without a separator the last line of the context would be fused
		// with the first line of the pattern (e.g. "...end.# IDENTITY").
		systemMessage = contextText + "\n" + patternText
	}
	if systemMessage != "" {
		ret = append(ret, &common.Message{Role: "system", Content: systemMessage})
	}

	if userMessage := strings.TrimSpace(o.Message); userMessage != "" {
		ret = append(ret, &common.Message{Role: "user", Content: userMessage})
	}

	if len(ret) == 0 {
		err = fmt.Errorf("no session, pattern or user messages provided")
	}
	return
}