diff --git a/cli/cli.go b/cli/cli.go index 12b32b9..1a5d4f4 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -142,43 +142,8 @@ func Cli() (message string, err error) { } } - if currentFlags.DryRun { - var patternContent string - var contextContent string - - if currentFlags.Pattern != "" { - pattern, patternErr := fabric.Db.Patterns.GetPattern(currentFlags.Pattern) - if patternErr != nil { - fmt.Printf("Error getting pattern content: %v\n", patternErr) - return "", patternErr - } - patternContent = pattern.Pattern // Assuming the content is stored in the 'Pattern' field - } - - if currentFlags.Context != "" { - context, contextErr := fabric.Db.Contexts.GetContext(currentFlags.Context) - if contextErr != nil { - fmt.Printf("Error getting context content: %v\n", contextErr) - return "", contextErr - } - contextContent = context.Content - } - - systemMessage := strings.TrimSpace(contextContent) + strings.TrimSpace(patternContent) - userMessage := strings.TrimSpace(currentFlags.Message) - - fmt.Println("Dry run: Would send the following request:\n") - if systemMessage != "" { - fmt.Printf("System:\n%s\n\n", systemMessage) - } - if userMessage != "" { - fmt.Printf("User:\n%s\n", userMessage) - } - return "", nil - } - var chatter *core.Chatter - if chatter, err = fabric.GetChatter(currentFlags.Model, currentFlags.Stream); err != nil { + if chatter, err = fabric.GetChatter(currentFlags.Model, currentFlags.Stream, currentFlags.DryRun); err != nil { return } diff --git a/core/chatter.go b/core/chatter.go index 70123f3..12dbfd1 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -2,6 +2,7 @@ package core import ( "fmt" + "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/vendors" @@ -11,6 +12,7 @@ type Chatter struct { db *db.Db Stream bool + DryRun bool model string vendor vendors.Vendor diff --git a/core/fabric.go b/core/fabric.go index 7e295d2..a93edeb 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -3,20 +3,22 @@ package core import ( "bytes" "fmt" + "os" + "strconv" + "strings" + "github.com/atotto/clipboard" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/vendors/anthropic" "github.com/danielmiessler/fabric/vendors/azure" + "github.com/danielmiessler/fabric/vendors/dryrun" "github.com/danielmiessler/fabric/vendors/gemini" "github.com/danielmiessler/fabric/vendors/groc" "github.com/danielmiessler/fabric/vendors/ollama" "github.com/danielmiessler/fabric/vendors/openai" "github.com/danielmiessler/fabric/youtube" "github.com/pkg/errors" - "os" - "strconv" - "strings" ) const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git" @@ -57,7 +59,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) { "Enter the index the name of your default model") ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), groc.NewClient(), - gemini.NewClient(), anthropic.NewClient()) + gemini.NewClient(), anthropic.NewClient(), dryrun.NewClient()) return } @@ -182,13 +184,20 @@ func (o *Fabric) configure() (err error) { return } -func (o *Fabric) GetChatter(model string, stream bool) (ret *Chatter, err error) { +func (o *Fabric) GetChatter(model string, stream bool, dryRun bool) (ret *Chatter, err error) { ret = &Chatter{ db: o.Db, Stream: stream, + DryRun: dryRun, } - if model == "" { + if dryRun { + ret.vendor = dryrun.NewClient() + ret.model = model + if ret.model == "" { + ret.model = o.DefaultModel.Value + } + } else if model == "" { ret.vendor = o.FindByName(o.DefaultVendor.Value) ret.model = o.DefaultModel.Value } else { diff --git a/core/models.go b/core/models.go index 980508e..2eaf775 100644 --- a/core/models.go +++ b/core/models.go @@ -16,8 +16,10 @@ type VendorsModels struct { } func (o *VendorsModels) AddVendorModels(vendor string, models []string) { - o.Vendors = append(o.Vendors, vendor) - o.VendorsModels[vendor] = models + if vendor != "DryRun" { + o.Vendors = append(o.Vendors, vendor) + o.VendorsModels[vendor] = models + } } func (o *VendorsModels) GetVendorAndModelByModelIndex(modelIndex int) (vendor string, model string) { diff --git a/vendors/dryrun/dryrun.go b/vendors/dryrun/dryrun.go new file mode 100644 index 0000000..0d2e246 --- /dev/null +++ b/vendors/dryrun/dryrun.go @@ -0,0 +1,88 @@ +package dryrun + +import ( + "bytes" + "fmt" + + "github.com/danielmiessler/fabric/common" +) + +type Client struct{} + +func NewClient() *Client { + return &Client{} +} + +func (c *Client) GetName() string { + return "DryRun" +} + +func (c *Client) IsConfigured() bool { + return true +} + +func (c *Client) Configure() error { + return nil +} + +func (c *Client) ListModels() ([]string, error) { + return []string{"dry-run-model"}, nil +} + +func (c *Client) SendStream(messages []*common.Message, options *common.ChatOptions, channel chan string) error { + output := "Dry run: Would send the following request:\n\n" + + for _, msg := range messages { + switch msg.Role { + case "system": + output += fmt.Sprintf("System:\n%s\n\n", msg.Content) + case "user": + output += fmt.Sprintf("User:\n%s\n\n", msg.Content) + default: + output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content) + } + } + + output += "Options:\n" + output += fmt.Sprintf("Model: %s\n", options.Model) + output += fmt.Sprintf("Temperature: %f\n", options.Temperature) + output += fmt.Sprintf("TopP: %f\n", options.TopP) + output += fmt.Sprintf("PresencePenalty: %f\n", options.PresencePenalty) + output += fmt.Sprintf("FrequencyPenalty: %f\n", options.FrequencyPenalty) + + channel <- output + close(channel) + return nil +} + +func (c *Client) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { + fmt.Println("Dry run: Would send the following request:") + + for _, msg := range messages { + switch msg.Role { + case "system": + fmt.Printf("System:\n%s\n\n", msg.Content) + case "user": + fmt.Printf("User:\n%s\n\n", msg.Content) + default: + fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content) + } + } + + fmt.Println("Options:") + fmt.Printf("Model: %s\n", options.Model) + fmt.Printf("Temperature: %f\n", options.Temperature) + fmt.Printf("TopP: %f\n", options.TopP) + fmt.Printf("PresencePenalty: %f\n", options.PresencePenalty) + fmt.Printf("FrequencyPenalty: %f\n", options.FrequencyPenalty) + + return "", nil +} + +func (c *Client) Setup() error { + return nil +} + +func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) { + // No environment variables needed for dry run +}