From 4b3afb3c8ef004855ab82ac9f3dba08bdf76cbac Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Thu, 22 Aug 2024 21:45:36 +0200 Subject: [PATCH] feat: simplify setup logic --- cli/cli.go | 30 +++++++-------- cli/cli_test.go | 23 ++++++++++++ cli/flags_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++ common/configurable.go | 7 ++++ core/fabric.go | 17 ++------- core/vendors.go | 19 +++++++--- 6 files changed, 148 insertions(+), 33 deletions(-) create mode 100644 cli/cli_test.go create mode 100644 cli/flags_test.go diff --git a/cli/cli.go b/cli/cli.go index 6028e65..bb6b773 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -14,7 +14,7 @@ import ( func Cli() (message string, err error) { var currentFlags *Flags if currentFlags, err = Init(); err != nil { - // we need to reset error, because we want to show double help messages + // we need to reset error, because we don't want to show double help messages err = nil return } @@ -24,23 +24,23 @@ func Cli() (message string, err error) { return } - db := db.NewDb(filepath.Join(homedir, ".config/fabric")) + fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric")) // if the setup flag is set, run the setup function if currentFlags.Setup { - _ = db.Configure() - _, err = Setup(db, currentFlags.SetupSkipUpdatePatterns) + _ = fabricDb.Configure() + _, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns) return } var fabric *core.Fabric - if err = db.Configure(); err != nil { + if err = fabricDb.Configure(); err != nil { fmt.Println("init is failed, run start the setup procedure", err) - if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil { + if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil { return } } else { - if fabric, err = core.NewFabric(db); err != nil { + if fabric, err = core.NewFabric(fabricDb); err != nil { fmt.Println("fabric can't initialize, please run the --setup procedure", err) return } @@ -64,7 +64,7 @@ func Cli() (message string, err error) { return } - if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil { + if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil { return } return @@ -72,7 +72,7 @@ func Cli() (message string, err error) { // if the list patterns flag is set, run the list all patterns function if currentFlags.ListPatterns { - err = db.Patterns.ListNames() + err = fabricDb.Patterns.ListNames() return } @@ -84,13 +84,13 @@ func Cli() (message string, err error) { // if the list all contexts flag is set, run the list all contexts function if currentFlags.ListAllContexts { - err = db.Contexts.ListNames() + err = fabricDb.Contexts.ListNames() return } // if the list all sessions flag is set, run the list all sessions function if currentFlags.ListAllSessions { - err = db.Sessions.ListNames() + err = fabricDb.Sessions.ListNames() return } @@ -129,17 +129,17 @@ func Cli() (message string, err error) { } func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) { - ret = core.NewFabricForSetup(db) + instance := core.NewFabricForSetup(db) - if err = ret.Setup(); err != nil { + if err = instance.Setup(); err != nil { return } if !skipUpdatePatterns { - if err = ret.PopulateDB(); err != nil { + if err = instance.PopulateDB(); err != nil { return } } - + ret = instance return } diff --git a/cli/cli_test.go b/cli/cli_test.go new file mode 100644 index 0000000..95b8701 --- /dev/null +++ b/cli/cli_test.go @@ -0,0 +1,23 @@ +package cli + +import ( + "os" + "testing" + + "github.com/danielmiessler/fabric/db" + "github.com/stretchr/testify/assert" +) + +func TestCli(t *testing.T) { + message, err := Cli() + assert.NoError(t, err) + assert.Empty(t, message) +} + +func TestSetup(t *testing.T) { + mockDB := db.NewDb(os.TempDir()) + + fabric, err := Setup(mockDB, false) + assert.Error(t, err) + assert.Nil(t, fabric) +} diff --git a/cli/flags_test.go b/cli/flags_test.go new file mode 100644 index 0000000..992d70d --- /dev/null +++ b/cli/flags_test.go @@ -0,0 +1,85 @@ +package cli + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/danielmiessler/fabric/common" + "github.com/stretchr/testify/assert" +) + +func TestInit(t *testing.T) { + args := []string{"--copy"} + expectedFlags := &Flags{Copy: true} + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = append([]string{"cmd"}, args...) + + flags, err := Init() + assert.NoError(t, err) + assert.Equal(t, expectedFlags.Copy, flags.Copy) +} + +func TestReadStdin(t *testing.T) { + input := "test input" + stdin := ioutil.NopCloser(strings.NewReader(input)) + // No need to cast stdin to *os.File, pass it as io.ReadCloser directly + content, err := ReadStdin(stdin) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != input { + t.Fatalf("expected %q, got %q", input, content) + } +} + +// ReadStdin function assuming it's part of `cli` package +func ReadStdin(reader io.ReadCloser) (string, error) { + defer reader.Close() + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(reader) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func TestBuildChatOptions(t *testing.T) { + flags := &Flags{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + } + + expectedOptions := &common.ChatOptions{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + } + options := flags.BuildChatOptions() + assert.Equal(t, expectedOptions, options) +} + +func TestBuildChatRequest(t *testing.T) { + flags := &Flags{ + Context: "test-context", + Session: "test-session", + Pattern: "test-pattern", + Message: "test-message", + } + + expectedRequest := &common.ChatRequest{ + ContextName: "test-context", + SessionName: "test-session", + PatternName: "test-pattern", + Message: "test-message", + } + request := flags.BuildChatRequest() + assert.Equal(t, expectedRequest, request) +} diff --git a/common/configurable.go b/common/configurable.go index 0ec61d0..1386f44 100644 --- a/common/configurable.go +++ b/common/configurable.go @@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) { return } +func (o *Configurable) SetupOrSkip() (err error) { + if err = o.Setup(); err != nil { + fmt.Printf("[%v] skipped\n", o.GetName()) + } + return +} + func NewSetting(envVariable string, required bool) *Setting { return &Setting{ EnvVariable: envVariable, diff --git a/core/fabric.go b/core/fabric.go index 99a0864..1289dc1 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) { return } - if youtubeErr := o.YouTube.Setup(); youtubeErr != nil { - fmt.Printf("[%v] skipped\n", o.YouTube.GetName()) - } + _ = o.YouTube.SetupOrSkip() if err = o.PatternsLoader.Setup(); err != nil { return @@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) { } func (o *Fabric) SetupVendors() (err error) { - o.Reset() - - for _, vendor := range o.VendorsAll.Vendors { - fmt.Println() - if vendorErr := vendor.Setup(); vendorErr == nil { - fmt.Printf("[%v] configured\n", vendor.GetName()) - o.AddVendors(vendor) - } else { - fmt.Printf("[%v] skipped\n", vendor.GetName()) - } + o.Models = nil + if o.Vendors, err = o.VendorsAll.Setup(); err != nil { + return } if !o.HasVendors() { diff --git a/core/vendors.go b/core/vendors.go index b81d26b..82f1a71 100644 --- a/core/vendors.go +++ b/core/vendors.go @@ -24,11 +24,6 @@ func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) { } } -func (o *VendorsManager) Reset() { - o.Vendors = map[string]vendors.Vendor{} - o.Models = nil -} - func (o *VendorsManager) GetModels() *VendorsModels { if o.Models == nil { o.readModels() @@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels( } } +func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) { + ret = map[string]vendors.Vendor{} + for _, vendor := range o.Vendors { + fmt.Println() + if vendorErr := vendor.Setup(); vendorErr == nil { + fmt.Printf("[%v] configured\n", vendor.GetName()) + ret[vendor.GetName()] = vendor + } else { + fmt.Printf("[%v] skipped\n", vendor.GetName()) + } + } + return +} + type modelResult struct { vendorName string models []string