mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
feat: simplify setup logic
This commit is contained in:
parent
6996278c8f
commit
4b3afb3c8e
30
cli/cli.go
30
cli/cli.go
@ -14,7 +14,7 @@ import (
|
||||
func Cli() (message string, err error) {
|
||||
var currentFlags *Flags
|
||||
if currentFlags, err = Init(); err != nil {
|
||||
// we need to reset error, because we want to show double help messages
|
||||
// we need to reset error, because we don't want to show double help messages
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
@ -24,23 +24,23 @@ func Cli() (message string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||
|
||||
// if the setup flag is set, run the setup function
|
||||
if currentFlags.Setup {
|
||||
_ = db.Configure()
|
||||
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
|
||||
_ = fabricDb.Configure()
|
||||
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
|
||||
return
|
||||
}
|
||||
|
||||
var fabric *core.Fabric
|
||||
if err = db.Configure(); err != nil {
|
||||
if err = fabricDb.Configure(); err != nil {
|
||||
fmt.Println("init is failed, run start the setup procedure", err)
|
||||
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if fabric, err = core.NewFabric(db); err != nil {
|
||||
if fabric, err = core.NewFabric(fabricDb); err != nil {
|
||||
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
||||
return
|
||||
}
|
||||
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
@ -72,7 +72,7 @@ func Cli() (message string, err error) {
|
||||
|
||||
// if the list patterns flag is set, run the list all patterns function
|
||||
if currentFlags.ListPatterns {
|
||||
err = db.Patterns.ListNames()
|
||||
err = fabricDb.Patterns.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
@ -84,13 +84,13 @@ func Cli() (message string, err error) {
|
||||
|
||||
// if the list all contexts flag is set, run the list all contexts function
|
||||
if currentFlags.ListAllContexts {
|
||||
err = db.Contexts.ListNames()
|
||||
err = fabricDb.Contexts.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
// if the list all sessions flag is set, run the list all sessions function
|
||||
if currentFlags.ListAllSessions {
|
||||
err = db.Sessions.ListNames()
|
||||
err = fabricDb.Sessions.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
@ -129,17 +129,17 @@ func Cli() (message string, err error) {
|
||||
}
|
||||
|
||||
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
||||
ret = core.NewFabricForSetup(db)
|
||||
instance := core.NewFabricForSetup(db)
|
||||
|
||||
if err = ret.Setup(); err != nil {
|
||||
if err = instance.Setup(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !skipUpdatePatterns {
|
||||
if err = ret.PopulateDB(); err != nil {
|
||||
if err = instance.PopulateDB(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ret = instance
|
||||
return
|
||||
}
|
||||
|
23
cli/cli_test.go
Normal file
23
cli/cli_test.go
Normal file
@ -0,0 +1,23 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCli(t *testing.T) {
|
||||
message, err := Cli()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, message)
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
mockDB := db.NewDb(os.TempDir())
|
||||
|
||||
fabric, err := Setup(mockDB, false)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, fabric)
|
||||
}
|
85
cli/flags_test.go
Normal file
85
cli/flags_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
args := []string{"--copy"}
|
||||
expectedFlags := &Flags{Copy: true}
|
||||
oldArgs := os.Args
|
||||
defer func() { os.Args = oldArgs }()
|
||||
os.Args = append([]string{"cmd"}, args...)
|
||||
|
||||
flags, err := Init()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedFlags.Copy, flags.Copy)
|
||||
}
|
||||
|
||||
func TestReadStdin(t *testing.T) {
|
||||
input := "test input"
|
||||
stdin := ioutil.NopCloser(strings.NewReader(input))
|
||||
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
|
||||
content, err := ReadStdin(stdin)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != input {
|
||||
t.Fatalf("expected %q, got %q", input, content)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStdin function assuming it's part of `cli` package
|
||||
func ReadStdin(reader io.ReadCloser) (string, error) {
|
||||
defer reader.Close()
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := buf.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func TestBuildChatOptions(t *testing.T) {
|
||||
flags := &Flags{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
}
|
||||
|
||||
expectedOptions := &common.ChatOptions{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
func TestBuildChatRequest(t *testing.T) {
|
||||
flags := &Flags{
|
||||
Context: "test-context",
|
||||
Session: "test-session",
|
||||
Pattern: "test-pattern",
|
||||
Message: "test-message",
|
||||
}
|
||||
|
||||
expectedRequest := &common.ChatRequest{
|
||||
ContextName: "test-context",
|
||||
SessionName: "test-session",
|
||||
PatternName: "test-pattern",
|
||||
Message: "test-message",
|
||||
}
|
||||
request := flags.BuildChatRequest()
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
}
|
@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Configurable) SetupOrSkip() (err error) {
|
||||
if err = o.Setup(); err != nil {
|
||||
fmt.Printf("[%v] skipped\n", o.GetName())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewSetting(envVariable string, required bool) *Setting {
|
||||
return &Setting{
|
||||
EnvVariable: envVariable,
|
||||
|
@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if youtubeErr := o.YouTube.Setup(); youtubeErr != nil {
|
||||
fmt.Printf("[%v] skipped\n", o.YouTube.GetName())
|
||||
}
|
||||
_ = o.YouTube.SetupOrSkip()
|
||||
|
||||
if err = o.PatternsLoader.Setup(); err != nil {
|
||||
return
|
||||
@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
||||
}
|
||||
|
||||
func (o *Fabric) SetupVendors() (err error) {
|
||||
o.Reset()
|
||||
|
||||
for _, vendor := range o.VendorsAll.Vendors {
|
||||
fmt.Println()
|
||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||
o.AddVendors(vendor)
|
||||
} else {
|
||||
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
||||
}
|
||||
o.Models = nil
|
||||
if o.Vendors, err = o.VendorsAll.Setup(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !o.HasVendors() {
|
||||
|
@ -24,11 +24,6 @@ func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *VendorsManager) Reset() {
|
||||
o.Vendors = map[string]vendors.Vendor{}
|
||||
o.Models = nil
|
||||
}
|
||||
|
||||
func (o *VendorsManager) GetModels() *VendorsModels {
|
||||
if o.Models == nil {
|
||||
o.readModels()
|
||||
@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels(
|
||||
}
|
||||
}
|
||||
|
||||
func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) {
|
||||
ret = map[string]vendors.Vendor{}
|
||||
for _, vendor := range o.Vendors {
|
||||
fmt.Println()
|
||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||
ret[vendor.GetName()] = vendor
|
||||
} else {
|
||||
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type modelResult struct {
|
||||
vendorName string
|
||||
models []string
|
||||
|
Loading…
Reference in New Issue
Block a user