feat: simplify setup logic

This commit is contained in:
Eugen Eisler 2024-08-22 21:45:36 +02:00
parent 6996278c8f
commit 4b3afb3c8e
6 changed files with 148 additions and 33 deletions

View File

@ -14,7 +14,7 @@ import (
func Cli() (message string, err error) {
var currentFlags *Flags
if currentFlags, err = Init(); err != nil {
// we need to reset error, because we want to show double help messages
// we need to reset error, because we don't want to show double help messages
err = nil
return
}
@ -24,23 +24,23 @@ func Cli() (message string, err error) {
return
}
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
// if the setup flag is set, run the setup function
if currentFlags.Setup {
_ = db.Configure()
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
_ = fabricDb.Configure()
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
return
}
var fabric *core.Fabric
if err = db.Configure(); err != nil {
if err = fabricDb.Configure(); err != nil {
fmt.Println("init is failed, run start the setup procedure", err)
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
return
}
} else {
if fabric, err = core.NewFabric(db); err != nil {
if fabric, err = core.NewFabric(fabricDb); err != nil {
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
return
}
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
return
}
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
return
}
return
@ -72,7 +72,7 @@ func Cli() (message string, err error) {
// if the list patterns flag is set, run the list all patterns function
if currentFlags.ListPatterns {
err = db.Patterns.ListNames()
err = fabricDb.Patterns.ListNames()
return
}
@ -84,13 +84,13 @@ func Cli() (message string, err error) {
// if the list all contexts flag is set, run the list all contexts function
if currentFlags.ListAllContexts {
err = db.Contexts.ListNames()
err = fabricDb.Contexts.ListNames()
return
}
// if the list all sessions flag is set, run the list all sessions function
if currentFlags.ListAllSessions {
err = db.Sessions.ListNames()
err = fabricDb.Sessions.ListNames()
return
}
@ -129,17 +129,17 @@ func Cli() (message string, err error) {
}
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
ret = core.NewFabricForSetup(db)
instance := core.NewFabricForSetup(db)
if err = ret.Setup(); err != nil {
if err = instance.Setup(); err != nil {
return
}
if !skipUpdatePatterns {
if err = ret.PopulateDB(); err != nil {
if err = instance.PopulateDB(); err != nil {
return
}
}
ret = instance
return
}

23
cli/cli_test.go Normal file
View File

@ -0,0 +1,23 @@
package cli
import (
"os"
"testing"
"github.com/danielmiessler/fabric/db"
"github.com/stretchr/testify/assert"
)
func TestCli(t *testing.T) {
message, err := Cli()
assert.NoError(t, err)
assert.Empty(t, message)
}
func TestSetup(t *testing.T) {
mockDB := db.NewDb(os.TempDir())
fabric, err := Setup(mockDB, false)
assert.Error(t, err)
assert.Nil(t, fabric)
}

85
cli/flags_test.go Normal file
View File

@ -0,0 +1,85 @@
package cli
import (
"bytes"
"io"
"io/ioutil"
"os"
"strings"
"testing"
"github.com/danielmiessler/fabric/common"
"github.com/stretchr/testify/assert"
)
func TestInit(t *testing.T) {
args := []string{"--copy"}
expectedFlags := &Flags{Copy: true}
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = append([]string{"cmd"}, args...)
flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, expectedFlags.Copy, flags.Copy)
}
func TestReadStdin(t *testing.T) {
input := "test input"
stdin := ioutil.NopCloser(strings.NewReader(input))
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
content, err := ReadStdin(stdin)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != input {
t.Fatalf("expected %q, got %q", input, content)
}
}
// ReadStdin function assuming it's part of `cli` package
func ReadStdin(reader io.ReadCloser) (string, error) {
defer reader.Close()
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(reader)
if err != nil {
return "", err
}
return buf.String(), nil
}
func TestBuildChatOptions(t *testing.T) {
flags := &Flags{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}
expectedOptions := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}
func TestBuildChatRequest(t *testing.T) {
flags := &Flags{
Context: "test-context",
Session: "test-session",
Pattern: "test-pattern",
Message: "test-message",
}
expectedRequest := &common.ChatRequest{
ContextName: "test-context",
SessionName: "test-session",
PatternName: "test-pattern",
Message: "test-message",
}
request := flags.BuildChatRequest()
assert.Equal(t, expectedRequest, request)
}

View File

@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) {
return
}
func (o *Configurable) SetupOrSkip() (err error) {
if err = o.Setup(); err != nil {
fmt.Printf("[%v] skipped\n", o.GetName())
}
return
}
func NewSetting(envVariable string, required bool) *Setting {
return &Setting{
EnvVariable: envVariable,

View File

@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) {
return
}
if youtubeErr := o.YouTube.Setup(); youtubeErr != nil {
fmt.Printf("[%v] skipped\n", o.YouTube.GetName())
}
_ = o.YouTube.SetupOrSkip()
if err = o.PatternsLoader.Setup(); err != nil {
return
@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) {
}
func (o *Fabric) SetupVendors() (err error) {
o.Reset()
for _, vendor := range o.VendorsAll.Vendors {
fmt.Println()
if vendorErr := vendor.Setup(); vendorErr == nil {
fmt.Printf("[%v] configured\n", vendor.GetName())
o.AddVendors(vendor)
} else {
fmt.Printf("[%v] skipped\n", vendor.GetName())
}
o.Models = nil
if o.Vendors, err = o.VendorsAll.Setup(); err != nil {
return
}
if !o.HasVendors() {

View File

@ -24,11 +24,6 @@ func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) {
}
}
func (o *VendorsManager) Reset() {
o.Vendors = map[string]vendors.Vendor{}
o.Models = nil
}
func (o *VendorsManager) GetModels() *VendorsModels {
if o.Models == nil {
o.readModels()
@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels(
}
}
func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) {
ret = map[string]vendors.Vendor{}
for _, vendor := range o.Vendors {
fmt.Println()
if vendorErr := vendor.Setup(); vendorErr == nil {
fmt.Printf("[%v] configured\n", vendor.GetName())
ret[vendor.GetName()] = vendor
} else {
fmt.Printf("[%v] skipped\n", vendor.GetName())
}
}
return
}
type modelResult struct {
vendorName string
models []string