diff --git a/cli/output.go b/cli/output.go new file mode 100644 index 0000000..f65b6ca --- /dev/null +++ b/cli/output.go @@ -0,0 +1,27 @@ +package cli + +import ( + "fmt" + "github.com/atotto/clipboard" + "os" +) + +func CopyToClipboard(message string) (err error) { + if err = clipboard.WriteAll(message); err != nil { + err = fmt.Errorf("could not copy to clipboard: %v", err) + } + return +} + +func CreateOutputFile(message string, fileName string) (err error) { + var file *os.File + if file, err = os.Create(fileName); err != nil { + err = fmt.Errorf("error creating file: %v", err) + return + } + defer file.Close() + if _, err = file.WriteString(message); err != nil { + err = fmt.Errorf("error writing to file: %v", err) + } + return +} diff --git a/cli/output_test.go b/cli/output_test.go new file mode 100644 index 0000000..1bfa1e2 --- /dev/null +++ b/cli/output_test.go @@ -0,0 +1,28 @@ +package cli + +import ( + "os" + "testing" +) + +func TestCopyToClipboard(t *testing.T) { + t.Skip("skipping test, because of docker env. in ci.") + + message := "test message" + err := CopyToClipboard(message) + if err != nil { + t.Fatalf("CopyToClipboard() error = %v", err) + } +} + +func TestCreateOutputFile(t *testing.T) { + + fileName := "test_output.txt" + message := "test message" + err := CreateOutputFile(message, fileName) + if err != nil { + t.Fatalf("CreateOutputFile() error = %v", err) + } + + defer os.Remove(fileName) +} diff --git a/common/groups_items.go b/common/groups_items.go new file mode 100644 index 0000000..b8d77c7 --- /dev/null +++ b/common/groups_items.go @@ -0,0 +1,134 @@ +package common + +import ( + "fmt" + "github.com/samber/lo" +) + +func NewGroupsItemsSelector[I any](selectionLabel string, + getItemLabel func(I) string) *GroupsItemsSelector[I] { + + return &GroupsItemsSelector[I]{SelectionLabel: selectionLabel, + GetItemKey: getItemLabel, + GroupsItems: make([]*GroupItems[I], 0), + } +} + +type GroupItems[I any] struct { + Group string + Items []I +} + +func (o *GroupItems[I]) Count() int { + return len(o.Items) +} + +func (o *GroupItems[I]) ContainsItemBy(predicate func(item I) bool) (ret bool) { + ret = lo.ContainsBy(o.Items, predicate) + return +} + +type GroupsItemsSelector[I any] struct { + SelectionLabel string + GetItemKey func(I) string + + GroupsItems []*GroupItems[I] +} + +func (o *GroupsItemsSelector[I]) AddGroupItems(group string, items ...I) { + o.GroupsItems = append(o.GroupsItems, &GroupItems[I]{group, items}) +} + +func (o *GroupsItemsSelector[I]) GetGroupAndItemByItemNumber(number int) (group string, item I, err error) { + var currentItemNumber int + found := false + + for _, groupItems := range o.GroupsItems { + if currentItemNumber+groupItems.Count() < number { + currentItemNumber += groupItems.Count() + continue + } + + for _, groupItem := range groupItems.Items { + currentItemNumber++ + if currentItemNumber == number { + group = groupItems.Group + item = groupItem + found = true + break + } + } + } + + if !found { + err = fmt.Errorf("number %d is out of range", number) + } + return +} + +func (o *GroupsItemsSelector[I]) Print() { + fmt.Printf("\n%v:\n", o.SelectionLabel) + + var currentItemIndex int + for _, groupItems := range o.GroupsItems { + fmt.Println() + fmt.Printf("%s\n", groupItems.Group) + fmt.Println() + + for _, item := range groupItems.Items { + currentItemIndex++ + fmt.Printf("\t[%d]\t%s\n", currentItemIndex, o.GetItemKey(item)) + + } + } +} + +func (o *GroupsItemsSelector[I]) HasGroup(group string) (ret bool) { + for _, groupItems := range o.GroupsItems { + if ret = groupItems.Group == group; ret { + break + } + } + return +} + +func (o *GroupsItemsSelector[I]) FindGroupsByItemFirst(item I) (ret string) { + itemKey := o.GetItemKey(item) + + for _, groupItems := range o.GroupsItems { + if groupItems.ContainsItemBy(func(groupItem I) bool { + groupItemKey := o.GetItemKey(groupItem) + return groupItemKey == itemKey + }) { + ret = groupItems.Group + break + } + } + return +} + +func (o *GroupsItemsSelector[I]) FindGroupsByItem(item I) (groups []string) { + itemKey := o.GetItemKey(item) + + for _, groupItems := range o.GroupsItems { + if groupItems.ContainsItemBy(func(groupItem I) bool { + groupItemKey := o.GetItemKey(groupItem) + return groupItemKey == itemKey + }) { + groups = append(groups, groupItems.Group) + } + } + return +} + +func ReturnItem(item string) string { + return item +} + +func NewGroupsItemsSelectorString(selectionLabel string) *GroupsItemsSelectorString { + return &GroupsItemsSelectorString{GroupsItemsSelector: NewGroupsItemsSelector(selectionLabel, ReturnItem)} +} + +type GroupsItemsSelectorString struct { + *GroupsItemsSelector[string] +} diff --git a/core/plugin_registry.go b/core/plugin_registry.go new file mode 100644 index 0000000..879e553 --- /dev/null +++ b/core/plugin_registry.go @@ -0,0 +1,203 @@ +package core + +import ( + "bytes" + "fmt" + "github.com/danielmiessler/fabric/common" + "github.com/danielmiessler/fabric/plugins/ai/azure" + "github.com/danielmiessler/fabric/plugins/tools" + "github.com/samber/lo" + "strconv" + + "github.com/danielmiessler/fabric/plugins" + "github.com/danielmiessler/fabric/plugins/ai" + "github.com/danielmiessler/fabric/plugins/ai/anthropic" + "github.com/danielmiessler/fabric/plugins/ai/dryrun" + "github.com/danielmiessler/fabric/plugins/ai/gemini" + "github.com/danielmiessler/fabric/plugins/ai/groq" + "github.com/danielmiessler/fabric/plugins/ai/mistral" + "github.com/danielmiessler/fabric/plugins/ai/ollama" + "github.com/danielmiessler/fabric/plugins/ai/openai" + "github.com/danielmiessler/fabric/plugins/ai/openrouter" + "github.com/danielmiessler/fabric/plugins/ai/siliconcloud" + "github.com/danielmiessler/fabric/plugins/db/fsdb" + "github.com/danielmiessler/fabric/plugins/tools/jina" + "github.com/danielmiessler/fabric/plugins/tools/lang" + "github.com/danielmiessler/fabric/plugins/tools/youtube" +) + +func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry) { + ret = &PluginRegistry{ + Db: db, + VendorManager: ai.NewVendorsManager(), + VendorsAll: ai.NewVendorsManager(), + PatternsLoader: tools.NewPatternsLoader(db.Patterns), + YouTube: youtube.NewYouTube(), + Language: lang.NewLanguage(), + Jina: jina.NewClient(), + } + + ret.Defaults = tools.NeeDefaults(ret.VendorManager.GetModels) + + ret.VendorsAll.AddVendors(openai.NewClient(), ollama.NewClient(), azure.NewClient(), groq.NewClient(), + gemini.NewClient(), anthropic.NewClient(), siliconcloud.NewClient(), openrouter.NewClient(), mistral.NewClient()) + _ = ret.Configure() + + return +} + +type PluginRegistry struct { + Db *fsdb.Db + + VendorManager *ai.VendorsManager + VendorsAll *ai.VendorsManager + Defaults *tools.Defaults + PatternsLoader *tools.PatternsLoader + YouTube *youtube.YouTube + Language *lang.Language + Jina *jina.Client +} + +func (o *PluginRegistry) SaveEnvFile() (err error) { + // Now create the .env with all configured VendorsController info + var envFileContent bytes.Buffer + + o.Defaults.Settings.FillEnvFileContent(&envFileContent) + o.PatternsLoader.SetupFillEnvFileContent(&envFileContent) + + for _, vendor := range o.VendorManager.Vendors { + vendor.SetupFillEnvFileContent(&envFileContent) + } + + o.YouTube.SetupFillEnvFileContent(&envFileContent) + o.Jina.SetupFillEnvFileContent(&envFileContent) + o.Language.SetupFillEnvFileContent(&envFileContent) + + err = o.Db.SaveEnv(envFileContent.String()) + return +} + +func (o *PluginRegistry) Setup() (err error) { + setupQuestion := plugins.NewSetupQuestion("Enter the number of the plugin to setup") + groupsPlugins := common.NewGroupsItemsSelector[plugins.Plugin]("Available plugins", + func(plugin plugins.Plugin) string { + var configuredLabel string + if plugin.IsConfigured() { + configuredLabel = " (configured)" + } else { + configuredLabel = "" + } + return fmt.Sprintf("%v%v", plugin.GetSetupDescription(), configuredLabel) + }) + + groupsPlugins.AddGroupItems("AI Vendors [at least one, required]", lo.Map(o.VendorsAll.Vendors, + func(vendor ai.Vendor, _ int) plugins.Plugin { + return vendor + })...) + + groupsPlugins.AddGroupItems("Tools", o.Defaults, o.PatternsLoader, o.YouTube, o.Language, o.Jina) + + for { + groupsPlugins.Print() + + if answerErr := setupQuestion.Ask("Plugin Number"); answerErr != nil { + break + } + + if setupQuestion.Value == "" { + break + } + number, parseErr := strconv.Atoi(setupQuestion.Value) + setupQuestion.Value = "" + + if parseErr == nil { + var plugin plugins.Plugin + if _, plugin, err = groupsPlugins.GetGroupAndItemByItemNumber(number); err != nil { + return + } + + if pluginSetupErr := plugin.Setup(); pluginSetupErr != nil { + println(pluginSetupErr.Error()) + } else { + if err = o.SaveEnvFile(); err != nil { + break + } + } + + if _, ok := o.VendorManager.VendorsByName[plugin.GetName()]; !ok { + if vendor, ok := plugin.(ai.Vendor); ok { + o.VendorManager.AddVendors(vendor) + } + } + } else { + break + } + } + + err = o.SaveEnvFile() + + return +} + +func (o *PluginRegistry) SetupVendor(vendorName string) (err error) { + if err = o.VendorsAll.SetupVendor(vendorName, o.VendorManager.VendorsByName); err != nil { + return + } + err = o.SaveEnvFile() + return +} + +// Configure buildClient VendorsController based on the environment variables +func (o *PluginRegistry) Configure() (err error) { + for _, vendor := range o.VendorsAll.Vendors { + if vendorErr := vendor.Configure(); vendorErr == nil { + o.VendorManager.AddVendors(vendor) + } + } + _ = o.Defaults.Configure() + _ = o.PatternsLoader.Configure() + + //YouTube and Jina are not mandatory, so ignore not configured error + _ = o.YouTube.Configure() + _ = o.Jina.Configure() + _ = o.Language.Configure() + return +} + +func (o *PluginRegistry) GetChatter(model string, stream bool, dryRun bool) (ret *Chatter, err error) { + ret = &Chatter{ + db: o.Db, + Stream: stream, + DryRun: dryRun, + } + + defaultModel := o.Defaults.Model.Value + defaultVendor := o.Defaults.Vendor.Value + vendorManager := o.VendorManager + + if dryRun { + ret.vendor = dryrun.NewClient() + ret.model = model + if ret.model == "" { + ret.model = defaultModel + } + } else if model == "" { + ret.vendor = vendorManager.FindByName(defaultVendor) + ret.model = defaultModel + } else { + var models *ai.VendorsModels + if models, err = vendorManager.GetModels(); err != nil { + return + } + ret.vendor = vendorManager.FindByName(models.FindGroupsByItemFirst(model)) + ret.model = model + } + + if ret.vendor == nil { + err = fmt.Errorf( + "could not find vendor.\n Model = %s\n Model = %s\n Vendor = %s", + model, defaultModel, defaultVendor) + return + } + return +} diff --git a/core/plugin_registry_test.go b/core/plugin_registry_test.go new file mode 100644 index 0000000..76f8382 --- /dev/null +++ b/core/plugin_registry_test.go @@ -0,0 +1,16 @@ +package core + +import ( + "github.com/danielmiessler/fabric/plugins/db/fsdb" + "os" + "testing" +) + +func TestSaveEnvFile(t *testing.T) { + registry := NewPluginRegistry(fsdb.NewDb(os.TempDir())) + + err := registry.SaveEnvFile() + if err != nil { + t.Fatalf("SaveEnvFile() error = %v", err) + } +} diff --git a/plugins/ai/models.go b/plugins/ai/models.go new file mode 100644 index 0000000..a6ef7eb --- /dev/null +++ b/plugins/ai/models.go @@ -0,0 +1,13 @@ +package ai + +import ( + "github.com/danielmiessler/fabric/common" +) + +func NewVendorsModels() *VendorsModels { + return &VendorsModels{GroupsItemsSelectorString: common.NewGroupsItemsSelectorString("Available models")} +} + +type VendorsModels struct { + *common.GroupsItemsSelectorString +} diff --git a/plugins/ai/models_test.go b/plugins/ai/models_test.go new file mode 100644 index 0000000..b7e0ee4 --- /dev/null +++ b/plugins/ai/models_test.go @@ -0,0 +1,33 @@ +package ai + +import ( + "testing" +) + +func TestNewVendorsModels(t *testing.T) { + vendors := NewVendorsModels() + if vendors == nil { + t.Fatalf("NewVendorsModels() returned nil") + } + if len(vendors.GroupsItems) != 0 { + t.Fatalf("NewVendorsModels() returned non-empty VendorsModels map") + } +} + +func TestFindVendorsByModelFirst(t *testing.T) { + vendors := NewVendorsModels() + vendors.AddGroupItems("vendor1", []string{"model1", "model2"}...) + vendor := vendors.FindGroupsByItemFirst("model1") + if vendor != "vendor1" { + t.Fatalf("FindVendorsByModelFirst() = %v, want %v", vendor, "vendor1") + } +} + +func TestFindVendorsByModel(t *testing.T) { + vendors := NewVendorsModels() + vendors.AddGroupItems("vendor1", []string{"model1", "model2"}...) + foundVendors := vendors.FindGroupsByItem("model1") + if len(foundVendors) != 1 || foundVendors[0] != "vendor1" { + t.Fatalf("FindVendorsByModel() = %v, want %v", foundVendors, []string{"vendor1"}) + } +} diff --git a/plugins/ai/vendors.go b/plugins/ai/vendors.go new file mode 100644 index 0000000..8afc798 --- /dev/null +++ b/plugins/ai/vendors.go @@ -0,0 +1,147 @@ +package ai + +import ( + "bytes" + "context" + "fmt" + "github.com/danielmiessler/fabric/plugins" + "sync" +) + +func NewVendorsManager() *VendorsManager { + return &VendorsManager{ + Vendors: []Vendor{}, + VendorsByName: map[string]Vendor{}, + } +} + +type VendorsManager struct { + *plugins.PluginBase + Vendors []Vendor + VendorsByName map[string]Vendor + Models *VendorsModels +} + +func (o *VendorsManager) AddVendors(vendors ...Vendor) { + for _, vendor := range vendors { + o.VendorsByName[vendor.GetName()] = vendor + o.Vendors = append(o.Vendors, vendor) + } +} + +func (o *VendorsManager) SetupFillEnvFileContent(envFileContent *bytes.Buffer) { + for _, vendor := range o.Vendors { + vendor.SetupFillEnvFileContent(envFileContent) + } +} + +func (o *VendorsManager) GetModels() (ret *VendorsModels, err error) { + if o.Models == nil { + err = o.readModels() + } + ret = o.Models + return +} + +func (o *VendorsManager) Configure() (err error) { + for _, vendor := range o.Vendors { + _ = vendor.Configure() + } + return +} + +func (o *VendorsManager) HasVendors() bool { + return len(o.Vendors) > 0 +} + +func (o *VendorsManager) FindByName(name string) Vendor { + return o.VendorsByName[name] +} + +func (o *VendorsManager) readModels() (err error) { + if len(o.Vendors) == 0 { + + err = fmt.Errorf("no AI vendors configured to read models from. Please configure at least one AI vendor") + return + } + + o.Models = NewVendorsModels() + + var wg sync.WaitGroup + resultsChan := make(chan modelResult, len(o.Vendors)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, vendor := range o.Vendors { + wg.Add(1) + go o.fetchVendorModels(ctx, &wg, vendor, resultsChan) + } + + // Wait for all goroutines to finish + go func() { + wg.Wait() + close(resultsChan) + }() + + // Collect results + for result := range resultsChan { + if result.err != nil { + fmt.Println(result.vendorName, result.err) + cancel() // Cancel remaining goroutines if needed + } else { + o.Models.AddGroupItems(result.vendorName, result.models...) + } + } + return +} + +func (o *VendorsManager) fetchVendorModels( + ctx context.Context, wg *sync.WaitGroup, vendor Vendor, resultsChan chan<- modelResult) { + + defer wg.Done() + + models, err := vendor.ListModels() + select { + case <-ctx.Done(): + // Context canceled, don't send the result + return + case resultsChan <- modelResult{vendorName: vendor.GetName(), models: models, err: err}: + // Result sent + } +} + +func (o *VendorsManager) Setup() (ret map[string]Vendor, err error) { + ret = map[string]Vendor{} + for _, vendor := range o.Vendors { + fmt.Println() + o.setupVendorTo(vendor, ret) + } + return +} + +func (o *VendorsManager) SetupVendor(vendorName string, configuredVendors map[string]Vendor) (err error) { + vendor := o.FindByName(vendorName) + if vendor == nil { + err = fmt.Errorf("vendor %s not found", vendorName) + return + } + o.setupVendorTo(vendor, configuredVendors) + return +} + +func (o *VendorsManager) setupVendorTo(vendor Vendor, configuredVendors map[string]Vendor) { + if vendorErr := vendor.Setup(); vendorErr == nil { + fmt.Printf("[%v] configured\n", vendor.GetName()) + configuredVendors[vendor.GetName()] = vendor + } else { + delete(configuredVendors, vendor.GetName()) + fmt.Printf("[%v] skipped\n", vendor.GetName()) + } + return +} + +type modelResult struct { + vendorName string + models []string + err error +} diff --git a/plugins/db/fsdb/contexts.go b/plugins/db/fsdb/contexts.go new file mode 100644 index 0000000..f306e93 --- /dev/null +++ b/plugins/db/fsdb/contexts.go @@ -0,0 +1,32 @@ +package fsdb + +import "fmt" + +type ContextsEntity struct { + *StorageEntity +} + +// Get Load a context from file +func (o *ContextsEntity) Get(name string) (ret *Context, err error) { + var content []byte + if content, err = o.Load(name); err != nil { + return + } + + ret = &Context{Name: name, Content: string(content)} + return +} + +func (o *ContextsEntity) PrintContext(name string) (err error) { + var context *Context + if context, err = o.Get(name); err != nil { + return + } + fmt.Println(context.Content) + return +} + +type Context struct { + Name string + Content string +} diff --git a/plugins/db/fsdb/contexts_test.go b/plugins/db/fsdb/contexts_test.go new file mode 100644 index 0000000..83c10b7 --- /dev/null +++ b/plugins/db/fsdb/contexts_test.go @@ -0,0 +1,29 @@ +package fsdb + +import ( + "os" + "path/filepath" + "testing" +) + +func TestContexts_GetContext(t *testing.T) { + dir := t.TempDir() + contexts := &ContextsEntity{ + StorageEntity: &StorageEntity{Dir: dir}, + } + contextName := "testContext" + contextPath := filepath.Join(dir, contextName) + contextContent := "test content" + err := os.WriteFile(contextPath, []byte(contextContent), 0644) + if err != nil { + t.Fatalf("failed to write context file: %v", err) + } + context, err := contexts.Get(contextName) + if err != nil { + t.Fatalf("failed to get context: %v", err) + } + expectedContext := &Context{Name: contextName, Content: contextContent} + if *context != *expectedContext { + t.Errorf("expected %v, got %v", expectedContext, context) + } +} diff --git a/plugins/db/fsdb/db.go b/plugins/db/fsdb/db.go new file mode 100644 index 0000000..4e458de --- /dev/null +++ b/plugins/db/fsdb/db.go @@ -0,0 +1,91 @@ +package fsdb + +import ( + "fmt" + "github.com/joho/godotenv" + "os" + "path/filepath" + "time" +) + +func NewDb(dir string) (db *Db) { + + db = &Db{Dir: dir} + + db.EnvFilePath = db.FilePath(".env") + + db.Patterns = &PatternsEntity{ + StorageEntity: &StorageEntity{Label: "Patterns", Dir: db.FilePath("patterns"), ItemIsDir: true}, + SystemPatternFile: "system.md", + UniquePatternsFilePath: db.FilePath("unique_patterns.txt"), + } + + db.Sessions = &SessionsEntity{ + &StorageEntity{Label: "Sessions", Dir: db.FilePath("sessions"), FileExtension: ".json"}} + + db.Contexts = &ContextsEntity{ + &StorageEntity{Label: "Contexts", Dir: db.FilePath("contexts")}} + + return +} + +type Db struct { + Dir string + + Patterns *PatternsEntity + Sessions *SessionsEntity + Contexts *ContextsEntity + + EnvFilePath string +} + +func (o *Db) Configure() (err error) { + if err = os.MkdirAll(o.Dir, os.ModePerm); err != nil { + return + } + + if err = o.LoadEnvFile(); err != nil { + return + } + + if err = o.Patterns.Configure(); err != nil { + return + } + + if err = o.Sessions.Configure(); err != nil { + return + } + + if err = o.Contexts.Configure(); err != nil { + return + } + + return +} + +func (o *Db) LoadEnvFile() (err error) { + if err = godotenv.Load(o.EnvFilePath); err != nil { + err = fmt.Errorf("error loading .env file: %s", err) + } + return +} + +func (o *Db) IsEnvFileExists() (ret bool) { + _, err := os.Stat(o.EnvFilePath) + ret = !os.IsNotExist(err) + return +} + +func (o *Db) SaveEnv(content string) (err error) { + err = os.WriteFile(o.EnvFilePath, []byte(content), 0644) + return +} + +func (o *Db) FilePath(fileName string) (ret string) { + return filepath.Join(o.Dir, fileName) +} + +type DirectoryChange struct { + Dir string + Timestamp time.Time +} diff --git a/plugins/db/fsdb/db_test.go b/plugins/db/fsdb/db_test.go new file mode 100644 index 0000000..3971d1e --- /dev/null +++ b/plugins/db/fsdb/db_test.go @@ -0,0 +1,55 @@ +package fsdb + +import ( + "os" + "testing" +) + +func TestDb_Configure(t *testing.T) { + dir := t.TempDir() + db := NewDb(dir) + err := db.Configure() + if err == nil { + t.Fatalf("db is configured, but must not be at empty dir: %v", dir) + } + if db.IsEnvFileExists() { + t.Fatalf("db file exists, but must not be at empty dir: %v", dir) + } + + err = db.SaveEnv("") + if err != nil { + t.Fatalf("db can't save env for empty conf.: %v", err) + } + + err = db.Configure() + if err != nil { + t.Fatalf("db is not configured, but shall be after save: %v", err) + } +} + +func TestDb_LoadEnvFile(t *testing.T) { + dir := t.TempDir() + db := NewDb(dir) + content := "KEY=VALUE\n" + err := os.WriteFile(db.EnvFilePath, []byte(content), 0644) + if err != nil { + t.Fatalf("failed to write .env file: %v", err) + } + err = db.LoadEnvFile() + if err != nil { + t.Errorf("failed to load .env file: %v", err) + } +} + +func TestDb_SaveEnv(t *testing.T) { + dir := t.TempDir() + db := NewDb(dir) + content := "KEY=VALUE\n" + err := db.SaveEnv(content) + if err != nil { + t.Errorf("failed to save .env file: %v", err) + } + if _, err := os.Stat(db.EnvFilePath); os.IsNotExist(err) { + t.Errorf("expected .env file to be saved") + } +} diff --git a/plugins/db/fsdb/patterns.go b/plugins/db/fsdb/patterns.go new file mode 100644 index 0000000..0daa185 --- /dev/null +++ b/plugins/db/fsdb/patterns.go @@ -0,0 +1,68 @@ +package fsdb + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +type PatternsEntity struct { + *StorageEntity + SystemPatternFile string + UniquePatternsFilePath string +} + +func (o *PatternsEntity) Get(name string) (ret *Pattern, err error) { + patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile) + + var pattern []byte + if pattern, err = os.ReadFile(patternPath); err != nil { + return + } + + patternStr := string(pattern) + ret = &Pattern{ + Name: name, + Pattern: patternStr, + } + return +} + +// GetApplyVariables finds a pattern by name and returns the pattern as an entry or an error +func (o *PatternsEntity) GetApplyVariables(name string, variables map[string]string) (ret *Pattern, err error) { + + if ret, err = o.Get(name); err != nil { + return + } + + if variables != nil && len(variables) > 0 { + for variableName, value := range variables { + ret.Pattern = strings.ReplaceAll(ret.Pattern, variableName, value) + } + } + return +} + +func (o *PatternsEntity) PrintLatestPatterns(latestNumber int) (err error) { + var contents []byte + if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil { + err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err) + return + } + uniquePatterns := strings.Split(string(contents), "\n") + if latestNumber > len(uniquePatterns) { + latestNumber = len(uniquePatterns) + } + + for i := len(uniquePatterns) - 1; i > len(uniquePatterns)-latestNumber-1; i-- { + fmt.Println(uniquePatterns[i]) + } + return +} + +type Pattern struct { + Name string + Description string + Pattern string +} diff --git a/plugins/db/fsdb/patterns_test.go b/plugins/db/fsdb/patterns_test.go new file mode 100644 index 0000000..e6e477c --- /dev/null +++ b/plugins/db/fsdb/patterns_test.go @@ -0,0 +1 @@ +package fsdb diff --git a/plugins/db/fsdb/sessions.go b/plugins/db/fsdb/sessions.go new file mode 100644 index 0000000..930c04f --- /dev/null +++ b/plugins/db/fsdb/sessions.go @@ -0,0 +1,88 @@ +package fsdb + +import ( + "fmt" + "github.com/danielmiessler/fabric/common" +) + +type SessionsEntity struct { + *StorageEntity +} + +func (o *SessionsEntity) Get(name string) (session *Session, err error) { + session = &Session{Name: name} + + if o.Exists(name) { + err = o.LoadAsJson(name, &session.Messages) + } else { + fmt.Printf("Creating new session: %s\n", name) + } + return +} + +func (o *SessionsEntity) PrintSession(name string) (err error) { + if o.Exists(name) { + var session Session + if err = o.LoadAsJson(name, &session.Messages); err == nil { + fmt.Println(session.String()) + } + } + return +} + +func (o *SessionsEntity) SaveSession(session *Session) (err error) { + return o.SaveAsJson(session.Name, session.Messages) +} + +type Session struct { + Name string + Messages []*common.Message + + vendorMessages []*common.Message +} + +func (o *Session) IsEmpty() bool { + return len(o.Messages) == 0 +} + +func (o *Session) Append(messages ...*common.Message) { + if o.vendorMessages != nil { + for _, message := range messages { + o.Messages = append(o.Messages, message) + o.appendVendorMessage(message) + } + } else { + o.Messages = append(o.Messages, messages...) + } +} + +func (o *Session) GetVendorMessages() (ret []*common.Message) { + if o.vendorMessages == nil { + o.vendorMessages = []*common.Message{} + for _, message := range o.Messages { + o.appendVendorMessage(message) + } + } + ret = o.vendorMessages + return +} + +func (o *Session) appendVendorMessage(message *common.Message) { + if message.Role != common.ChatMessageRoleMeta { + o.vendorMessages = append(o.vendorMessages, message) + } +} + +func (o *Session) GetLastMessage() (ret *common.Message) { + if len(o.Messages) > 0 { + ret = o.Messages[len(o.Messages)-1] + } + return +} + +func (o *Session) String() (ret string) { + for _, message := range o.Messages { + ret += fmt.Sprintf("\n--- \n[%v]\n\n%v", message.Role, message.Content) + } + return +} diff --git a/plugins/db/fsdb/sessions_test.go b/plugins/db/fsdb/sessions_test.go new file mode 100644 index 0000000..70de5ac --- /dev/null +++ b/plugins/db/fsdb/sessions_test.go @@ -0,0 +1,38 @@ +package fsdb + +import ( + "testing" + + "github.com/danielmiessler/fabric/common" +) + +func TestSessions_GetOrCreateSession(t *testing.T) { + dir := t.TempDir() + sessions := &SessionsEntity{ + StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"}, + } + sessionName := "testSession" + session, err := sessions.Get(sessionName) + if err != nil { + t.Fatalf("failed to get or create session: %v", err) + } + if session.Name != sessionName { + t.Errorf("expected session name %v, got %v", sessionName, session.Name) + } +} + +func TestSessions_SaveSession(t *testing.T) { + dir := t.TempDir() + sessions := &SessionsEntity{ + StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"}, + } + sessionName := "testSession" + session := &Session{Name: sessionName, Messages: []*common.Message{{Content: "message1"}}} + err := sessions.SaveSession(session) + if err != nil { + t.Fatalf("failed to save session: %v", err) + } + if !sessions.Exists(sessionName) { + t.Errorf("expected session to be saved") + } +} diff --git a/plugins/db/fsdb/storage.go b/plugins/db/fsdb/storage.go new file mode 100644 index 0000000..44e3e0b --- /dev/null +++ b/plugins/db/fsdb/storage.go @@ -0,0 +1,148 @@ +package fsdb + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/samber/lo" +) + +type StorageEntity struct { + Label string + Dir string + ItemIsDir bool + FileExtension string +} + +func (o *StorageEntity) Configure() (err error) { + if err = os.MkdirAll(o.Dir, os.ModePerm); err != nil { + return + } + return +} + +// GetNames finds all patterns in the patterns directory and enters the id, name, and pattern into a slice of Entry structs. it returns these entries or an error +func (o *StorageEntity) GetNames() (ret []string, err error) { + var entries []os.DirEntry + if entries, err = os.ReadDir(o.Dir); err != nil { + err = fmt.Errorf("could not read items from directory: %v", err) + return + } + + if o.ItemIsDir { + ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) { + if ok = item.IsDir(); ok { + ret = item.Name() + } + return + }) + } else { + if o.FileExtension == "" { + ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) { + if ok = !item.IsDir(); ok { + ret = item.Name() + } + return + }) + } else { + ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) { + if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.FileExtension; ok { + ret = strings.TrimSuffix(item.Name(), o.FileExtension) + } + return + }) + } + } + return +} + +func (o *StorageEntity) Delete(name string) (err error) { + if err = os.Remove(o.BuildFilePathByName(name)); err != nil { + err = fmt.Errorf("could not delete %s: %v", name, err) + } + return +} + +func (o *StorageEntity) Exists(name string) (ret bool) { + _, err := os.Stat(o.BuildFilePathByName(name)) + ret = !os.IsNotExist(err) + return +} + +func (o *StorageEntity) Rename(oldName, newName string) (err error) { + if err = os.Rename(o.BuildFilePathByName(oldName), o.BuildFilePathByName(newName)); err != nil { + err = fmt.Errorf("could not rename %s to %s: %v", oldName, newName, err) + } + return +} + +func (o *StorageEntity) Save(name string, content []byte) (err error) { + if err = os.WriteFile(o.BuildFilePathByName(name), content, 0644); err != nil { + err = fmt.Errorf("could not save %s: %v", name, err) + } + return +} + +func (o *StorageEntity) Load(name string) (ret []byte, err error) { + if ret, err = os.ReadFile(o.BuildFilePathByName(name)); err != nil { + err = fmt.Errorf("could not load %s: %v", name, err) + } + return +} + +func (o *StorageEntity) ListNames() (err error) { + var names []string + if names, err = o.GetNames(); err != nil { + return + } + + if len(names) == 0 { + fmt.Printf("\nNo %v\n", o.Label) + return + } + + for _, item := range names { + fmt.Printf("%s\n", item) + } + return +} + +func (o *StorageEntity) BuildFilePathByName(name string) (ret string) { + ret = o.BuildFilePath(o.buildFileName(name)) + return +} + +func (o *StorageEntity) BuildFilePath(fileName string) (ret string) { + ret = filepath.Join(o.Dir, fileName) + return +} + +func (o *StorageEntity) buildFileName(name string) string { + return fmt.Sprintf("%s%v", name, o.FileExtension) +} + +func (o *StorageEntity) SaveAsJson(name string, item interface{}) (err error) { + var jsonString []byte + if jsonString, err = json.Marshal(item); err == nil { + err = o.Save(name, jsonString) + } else { + err = fmt.Errorf("could not marshal %s: %s", name, err) + } + + return err +} + +func (o *StorageEntity) LoadAsJson(name string, item interface{}) (err error) { + var content []byte + if content, err = o.Load(name); err != nil { + return + } + + if err = json.Unmarshal(content, &item); err != nil { + err = fmt.Errorf("could not unmarshal %s: %s", name, err) + } + return +} diff --git a/plugins/db/fsdb/storage_test.go b/plugins/db/fsdb/storage_test.go new file mode 100644 index 0000000..761315e --- /dev/null +++ b/plugins/db/fsdb/storage_test.go @@ -0,0 +1,52 @@ +package fsdb + +import ( + "testing" +) + +func TestStorage_SaveAndLoad(t *testing.T) { + dir := t.TempDir() + storage := &StorageEntity{Dir: dir} + name := "test" + content := []byte("test content") + if err := storage.Save(name, content); err != nil { + t.Fatalf("failed to save content: %v", err) + } + loadedContent, err := storage.Load(name) + if err != nil { + t.Fatalf("failed to load content: %v", err) + } + if string(loadedContent) != string(content) { + t.Errorf("expected %v, got %v", string(content), string(loadedContent)) + } +} + +func TestStorage_Exists(t *testing.T) { + dir := t.TempDir() + storage := &StorageEntity{Dir: dir} + name := "test" + if storage.Exists(name) { + t.Errorf("expected file to not exist") + } + if err := storage.Save(name, []byte("test content")); err != nil { + t.Fatalf("failed to save content: %v", err) + } + if !storage.Exists(name) { + t.Errorf("expected file to exist") + } +} + +func TestStorage_Delete(t *testing.T) { + dir := t.TempDir() + storage := &StorageEntity{Dir: dir} + name := "test" + if err := storage.Save(name, []byte("test content")); err != nil { + t.Fatalf("failed to save content: %v", err) + } + if err := storage.Delete(name); err != nil { + t.Fatalf("failed to delete content: %v", err) + } + if storage.Exists(name) { + t.Errorf("expected file to be deleted") + } +} diff --git a/plugins/plugin.go b/plugins/plugin.go new file mode 100644 index 0000000..2574a35 --- /dev/null +++ b/plugins/plugin.go @@ -0,0 +1,242 @@ +package plugins + +import ( + "bytes" + "fmt" + "os" + "strings" +) + +const AnswerReset = "reset" + +type Plugin interface { + GetName() string + GetSetupDescription() string + IsConfigured() bool + Configure() error + Setup() error + SetupFillEnvFileContent(*bytes.Buffer) +} + +type PluginBase struct { + Settings + SetupQuestions + + Name string + SetupDescription string + EnvNamePrefix string + + ConfigureCustom func() error +} + +func (o *PluginBase) GetName() string { + return o.Name +} + +func (o *PluginBase) GetSetupDescription() (ret string) { + if ret = o.SetupDescription; ret == "" { + ret = o.GetName() + } + return +} + +func (o *PluginBase) AddSetting(name string, required bool) (ret *Setting) { + ret = NewSetting(fmt.Sprintf("%v%v", o.EnvNamePrefix, BuildEnvVariable(name)), required) + o.Settings = append(o.Settings, ret) + return +} + +func (o *PluginBase) AddSetupQuestion(name string, required bool) (ret *SetupQuestion) { + return o.AddSetupQuestionCustom(name, required, "") +} + +func (o *PluginBase) AddSetupQuestionCustom(name string, required bool, question string) (ret *SetupQuestion) { + setting := o.AddSetting(name, required) + ret = &SetupQuestion{Setting: setting, Question: question} + if ret.Question == "" { + ret.Question = fmt.Sprintf("Enter your %v %v", o.Name, strings.ToUpper(name)) + } + o.SetupQuestions = append(o.SetupQuestions, ret) + return +} + +func (o *PluginBase) Configure() (err error) { + if err = o.Settings.Configure(); err != nil { + return + } + + if o.ConfigureCustom != nil { + err = o.ConfigureCustom() + } + return +} + +func (o *PluginBase) Setup() (err error) { + if err = o.Ask(o.Name); err != nil { + return + } + + err = o.Configure() + return +} + +func (o *PluginBase) SetupOrSkip() (err error) { + if err = o.Setup(); err != nil { + fmt.Printf("[%v] skipped\n", o.GetName()) + } + return +} + +func (o *PluginBase) SetupFillEnvFileContent(fileEnvFileContent *bytes.Buffer) { + o.Settings.FillEnvFileContent(fileEnvFileContent) +} + +func NewSetting(envVariable string, required bool) *Setting { + return &Setting{ + EnvVariable: envVariable, + Required: required, + } +} + +type Setting struct { + EnvVariable string + Value string + Required bool +} + +func (o *Setting) IsValid() bool { + return o.IsDefined() || !o.Required +} + +func (o *Setting) IsValidErr() (err error) { + if !o.IsValid() { + err = fmt.Errorf("%v=%v, is not valid", o.EnvVariable, o.Value) + } + return +} + +func (o *Setting) IsDefined() bool { + return o.Value != "" +} + +func (o *Setting) Configure() error { + envValue := os.Getenv(o.EnvVariable) + if envValue != "" { + o.Value = envValue + } + return o.IsValidErr() +} + +func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) { + if o.IsDefined() { + buffer.WriteString(o.EnvVariable) + buffer.WriteString("=") + //buffer.WriteString("\"") + buffer.WriteString(o.Value) + //buffer.WriteString("\"") + buffer.WriteString("\n") + } + return +} + +func (o *Setting) Print() { + fmt.Printf("%v: %v\n", o.EnvVariable, o.Value) +} + +func NewSetupQuestion(question string) *SetupQuestion { + return &SetupQuestion{Setting: &Setting{}, Question: question} +} + +type SetupQuestion struct { + *Setting + Question string +} + +func (o *SetupQuestion) Ask(label string) (err error) { + var prefix string + + if label != "" { + prefix = fmt.Sprintf("[%v] ", label) + } else { + prefix = "" + } + + fmt.Println() + if o.Value != "" { + fmt.Printf("%v%v (leave empty for '%s' or type '%v' to remove the value):\n", + prefix, o.Question, o.Value, AnswerReset) + } else { + fmt.Printf("%v%v (leave empty to skip):\n", prefix, o.Question) + } + + var answer string + fmt.Scanln(&answer) + answer = strings.TrimRight(answer, "\n") + if answer == "" { + answer = o.Value + } else if strings.ToLower(answer) == AnswerReset { + answer = "" + } + err = o.OnAnswer(answer) + return +} + +func (o *SetupQuestion) OnAnswer(answer string) (err error) { + o.Value = answer + err = o.IsValidErr() + return +} + +type Settings []*Setting + +func (o Settings) IsConfigured() (ret bool) { + ret = true + for _, setting := range o { + if ret = setting.IsValid(); !ret { + break + } + } + return +} + +func (o Settings) Configure() (err error) { + for _, setting := range o { + if err = setting.Configure(); err != nil { + break + } + } + return +} + +func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) { + for _, setting := range o { + setting.FillEnvFileContent(buffer) + } + return +} + +type SetupQuestions []*SetupQuestion + +func (o SetupQuestions) Ask(label string) (err error) { + fmt.Println() + fmt.Printf("[%v]\n", label) + for _, question := range o { + if err = question.Ask(""); err != nil { + break + } + } + return +} + +func BuildEnvVariablePrefix(name string) (ret string) { + ret = BuildEnvVariable(name) + if ret != "" { + ret += "_" + } + return +} + +func BuildEnvVariable(name string) string { + name = strings.TrimSpace(name) + return strings.ReplaceAll(strings.ToUpper(name), " ", "_") +} diff --git a/plugins/plugin_test.go b/plugins/plugin_test.go new file mode 100644 index 0000000..62f3ed8 --- /dev/null +++ b/plugins/plugin_test.go @@ -0,0 +1,176 @@ +package plugins + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigurable_AddSetting(t *testing.T) { + conf := &PluginBase{ + Settings: Settings{}, + Name: "TestConfigurable", + EnvNamePrefix: "TEST_", + } + + setting := conf.AddSetting("test_setting", true) + assert.Equal(t, "TEST_TEST_SETTING", setting.EnvVariable) + assert.True(t, setting.Required) + assert.Contains(t, conf.Settings, setting) +} + +func TestConfigurable_Configure(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + conf := &PluginBase{ + Settings: Settings{setting}, + Name: "TestConfigurable", + } + + _ = os.Setenv("TEST_SETTING", "test_value") + err := conf.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestConfigurable_Setup(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: false, + } + conf := &PluginBase{ + Settings: Settings{setting}, + Name: "TestConfigurable", + } + + err := conf.Setup() + assert.NoError(t, err) +} + +func TestSetting_IsValid(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "some_value", + Required: true, + } + + assert.True(t, setting.IsValid()) + + setting.Value = "" + assert.False(t, setting.IsValid()) +} + +func TestSetting_Configure(t *testing.T) { + _ = os.Setenv("TEST_SETTING", "test_value") + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + err := setting.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestSetting_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + setting.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +func TestSetting_Print(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + expected := "TEST_SETTING: test_value\n" + fmtOutput := captureOutput(func() { + setting.Print() + }) + assert.Equal(t, expected, fmtOutput) +} + +func TestSetupQuestion_Ask(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + question := &SetupQuestion{ + Setting: setting, + Question: "Enter test setting:", + } + input := "user_value\n" + fmtInput := captureInput(input) + defer fmtInput() + err := question.Ask("TestConfigurable") + assert.NoError(t, err) + assert.Equal(t, "user_value", setting.Value) +} + +func TestSettings_IsConfigured(t *testing.T) { + settings := Settings{ + {EnvVariable: "TEST_SETTING1", Value: "value1", Required: true}, + {EnvVariable: "TEST_SETTING2", Value: "", Required: false}, + } + + assert.True(t, settings.IsConfigured()) + + settings[0].Value = "" + assert.False(t, settings.IsConfigured()) +} + +func TestSettings_Configure(t *testing.T) { + _ = os.Setenv("TEST_SETTING", "test_value") + settings := Settings{ + {EnvVariable: "TEST_SETTING", Required: true}, + } + + err := settings.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", settings[0].Value) +} + +func TestSettings_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + settings := Settings{ + {EnvVariable: "TEST_SETTING", Value: "test_value"}, + } + settings.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +// captureOutput captures the output of a function call +func captureOutput(f func()) string { + var buf bytes.Buffer + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + f() + _ = w.Close() + os.Stdout = stdout + _, _ = buf.ReadFrom(r) + return buf.String() +} + +// captureInput captures the input for a function call +func captureInput(input string) func() { + r, w, _ := os.Pipe() + _, _ = w.WriteString(input) + _ = w.Close() + stdin := os.Stdin + os.Stdin = r + return func() { + os.Stdin = stdin + } +} diff --git a/plugins/tools/defaults.go b/plugins/tools/defaults.go new file mode 100644 index 0000000..16b5330 --- /dev/null +++ b/plugins/tools/defaults.go @@ -0,0 +1,71 @@ +package tools + +import ( + "fmt" + "strconv" + + "github.com/danielmiessler/fabric/plugins" + "github.com/danielmiessler/fabric/plugins/ai" + "github.com/pkg/errors" +) + +func NeeDefaults(getVendorsModels func() (*ai.VendorsModels, error)) (ret *Defaults) { + vendorName := "Default" + ret = &Defaults{ + PluginBase: &plugins.PluginBase{ + Name: vendorName, + SetupDescription: "Default AI Vendor and Model [required]", + EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), + }, + GetVendorsModels: getVendorsModels, + } + + ret.Vendor = ret.AddSetting("Vendor", true) + ret.Model = ret.AddSetupQuestionCustom("Model", true, + "Enter the index the name of your default model") + + return +} + +type Defaults struct { + *plugins.PluginBase + + Vendor *plugins.Setting + Model *plugins.SetupQuestion + GetVendorsModels func() (*ai.VendorsModels, error) +} + +func (o *Defaults) Setup() (err error) { + var vendorsModels *ai.VendorsModels + if vendorsModels, err = o.GetVendorsModels(); err != nil { + return + } + + vendorsModels.Print() + + if err = o.Ask(o.Name); err != nil { + return + } + + index, parseErr := strconv.Atoi(o.Model.Value) + if parseErr == nil { + if o.Vendor.Value, o.Model.Value, err = vendorsModels.GetGroupAndItemByItemNumber(index); err != nil { + return + } + } else { + o.Vendor.Value = vendorsModels.FindGroupsByItemFirst(o.Model.Value) + } + + //verify + vendorNames := vendorsModels.FindGroupsByItem(o.Model.Value) + if len(vendorNames) == 0 { + err = errors.Errorf("You need to chose an available default model.") + return + } + + fmt.Println() + o.Vendor.Print() + o.Model.Print() + + return +} diff --git a/plugins/tools/patterns_loader.go b/plugins/tools/patterns_loader.go new file mode 100644 index 0000000..815bbbe --- /dev/null +++ b/plugins/tools/patterns_loader.go @@ -0,0 +1,303 @@ +package tools + +import ( + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/danielmiessler/fabric/plugins" + "github.com/danielmiessler/fabric/plugins/db/fsdb" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing" + "github.com/go-git/go-git/v5/plumbing/object" + "github.com/go-git/go-git/v5/storage/memory" + "github.com/otiai10/copy" +) + +const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git" +const DefaultPatternsGitRepoFolder = "patterns" + +func NewPatternsLoader(patterns *fsdb.PatternsEntity) (ret *PatternsLoader) { + label := "Patterns Loader" + ret = &PatternsLoader{ + Patterns: patterns, + loadedFilePath: patterns.BuildFilePath("loaded"), + } + + ret.PluginBase = &plugins.PluginBase{ + Name: label, + SetupDescription: "Patterns - Downloads patterns [required]", + EnvNamePrefix: plugins.BuildEnvVariablePrefix(label), + ConfigureCustom: ret.configure, + } + + ret.DefaultGitRepoUrl = ret.AddSetupQuestionCustom("Git Repo Url", true, + "Enter the default Git repository URL for the patterns") + ret.DefaultGitRepoUrl.Value = DefaultPatternsGitRepoUrl + + ret.DefaultFolder = ret.AddSetupQuestionCustom("Git Repo Patterns Folder", true, + "Enter the default folder in the Git repository where patterns are stored") + ret.DefaultFolder.Value = DefaultPatternsGitRepoFolder + + return +} + +type PatternsLoader struct { + *plugins.PluginBase + Patterns *fsdb.PatternsEntity + + DefaultGitRepoUrl *plugins.SetupQuestion + DefaultFolder *plugins.SetupQuestion + + loadedFilePath string + + pathPatternsPrefix string + tempPatternsFolder string +} + +func (o *PatternsLoader) configure() (err error) { + o.pathPatternsPrefix = fmt.Sprintf("%v/", o.DefaultFolder.Value) + o.tempPatternsFolder = filepath.Join(os.TempDir(), o.DefaultFolder.Value) + + return +} + +func (o *PatternsLoader) IsConfigured() (ret bool) { + ret = o.PluginBase.IsConfigured() + if ret { + if _, err := os.Stat(o.loadedFilePath); os.IsNotExist(err) { + ret = false + } + } + return +} + +func (o *PatternsLoader) Setup() (err error) { + if err = o.PluginBase.Setup(); err != nil { + return + } + + if err = o.PopulateDB(); err != nil { + return + } + return +} + +// PopulateDB downloads patterns from the internet and populates the patterns folder +func (o *PatternsLoader) PopulateDB() (err error) { + fmt.Printf("Downloading patterns and Populating %s..\n", o.Patterns.Dir) + fmt.Println() + if err = o.gitCloneAndCopy(); err != nil { + return + } + + if err = o.movePatterns(); err != nil { + return + } + return +} + +// PersistPatterns copies custom patterns to the updated patterns directory +func (o *PatternsLoader) PersistPatterns() (err error) { + var currentPatterns []os.DirEntry + if currentPatterns, err = os.ReadDir(o.Patterns.Dir); err != nil { + return + } + + newPatternsFolder := o.tempPatternsFolder + var newPatterns []os.DirEntry + if newPatterns, err = os.ReadDir(newPatternsFolder); err != nil { + return + } + + for _, currentPattern := range currentPatterns { + for _, newPattern := range newPatterns { + if currentPattern.Name() == newPattern.Name() { + break + } + err = copy.Copy(filepath.Join(o.Patterns.Dir, newPattern.Name()), filepath.Join(newPatternsFolder, newPattern.Name())) + } + } + return +} + +// movePatterns copies the new patterns into the config directory +func (o *PatternsLoader) movePatterns() (err error) { + if err = os.MkdirAll(o.Patterns.Dir, os.ModePerm); err != nil { + return + } + + patternsDir := o.tempPatternsFolder + if err = o.PersistPatterns(); err != nil { + return + } + + if err = copy.Copy(patternsDir, o.Patterns.Dir); err != nil { // copies the patterns to the config directory + return + } + + //create an empty file to indicate that the patterns have been updated if not exists + _, _ = os.Create(o.loadedFilePath) + + err = os.RemoveAll(patternsDir) + return +} + +func (o *PatternsLoader) gitCloneAndCopy() (err error) { + // Clones the given repository, creating the remote, the local branches + // and fetching the objects, everything in memory: + var r *git.Repository + if r, err = git.Clone(memory.NewStorage(), nil, &git.CloneOptions{ + URL: o.DefaultGitRepoUrl.Value, + }); err != nil { + fmt.Println(err) + return + } + + // ... retrieves the branch pointed by HEAD + var ref *plumbing.Reference + if ref, err = r.Head(); err != nil { + fmt.Println(err) + return + } + + // ... retrieves the commit history for /patterns folder + var cIter object.CommitIter + if cIter, err = r.Log(&git.LogOptions{ + From: ref.Hash(), + PathFilter: func(path string) bool { + return path == o.DefaultFolder.Value || strings.HasPrefix(path, o.pathPatternsPrefix) + }, + }); err != nil { + fmt.Println(err) + return err + } + + var changes []fsdb.DirectoryChange + // ... iterates over the commits + if err = cIter.ForEach(func(c *object.Commit) (err error) { + // GetApplyVariables the files changed in this commit by comparing with its parents + parentIter := c.Parents() + if err = parentIter.ForEach(func(parent *object.Commit) (err error) { + var patch *object.Patch + if patch, err = parent.Patch(c); err != nil { + fmt.Println(err) + return + } + + for _, fileStat := range patch.Stats() { + if strings.HasPrefix(fileStat.Name, o.pathPatternsPrefix) { + dir := filepath.Dir(fileStat.Name) + changes = append(changes, fsdb.DirectoryChange{Dir: dir, Timestamp: c.Committer.When}) + } + } + return + }); err != nil { + fmt.Println(err) + return + } + return + }); err != nil { + fmt.Println(err) + return + } + + // Sort changes by timestamp + sort.Slice(changes, func(i, j int) bool { + return changes[i].Timestamp.Before(changes[j].Timestamp) + }) + + if err = o.makeUniqueList(changes); err != nil { + return + } + + var commit *object.Commit + if commit, err = r.CommitObject(ref.Hash()); err != nil { + fmt.Println(err) + return + } + + var tree *object.Tree + if tree, err = commit.Tree(); err != nil { + fmt.Println(err) + return + } + + if err = tree.Files().ForEach(func(f *object.File) (err error) { + if strings.HasPrefix(f.Name, o.pathPatternsPrefix) { + // Create the local file path + localPath := filepath.Join(os.TempDir(), f.Name) + + // Create the directories if they don't exist + if err = os.MkdirAll(filepath.Dir(localPath), os.ModePerm); err != nil { + fmt.Println(err) + return + } + + // Write the file to the local filesystem + var blob *object.Blob + if blob, err = r.BlobObject(f.Hash); err != nil { + fmt.Println(err) + return + } + err = o.writeBlobToFile(blob, localPath) + return + } + + return + }); err != nil { + fmt.Println(err) + } + + return +} + +func (o *PatternsLoader) writeBlobToFile(blob *object.Blob, path string) (err error) { + var reader io.ReadCloser + if reader, err = blob.Reader(); err != nil { + return + } + defer reader.Close() + + // Create the file + var file *os.File + if file, err = os.Create(path); err != nil { + return + } + defer file.Close() + + // Copy the contents of the blob to the file + if _, err = io.Copy(file, reader); err != nil { + return + } + return +} + +func (o *PatternsLoader) makeUniqueList(changes []fsdb.DirectoryChange) (err error) { + uniqueItems := make(map[string]bool) + for _, change := range changes { + if strings.TrimSpace(change.Dir) != "" && !strings.Contains(change.Dir, "=>") { + pattern := strings.ReplaceAll(change.Dir, o.pathPatternsPrefix, "") + pattern = strings.TrimSpace(pattern) + uniqueItems[pattern] = true + } + } + + finalList := make([]string, 0, len(uniqueItems)) + for _, change := range changes { + pattern := strings.ReplaceAll(change.Dir, o.pathPatternsPrefix, "") + pattern = strings.TrimSpace(pattern) + if _, exists := uniqueItems[pattern]; exists { + finalList = append(finalList, pattern) + delete(uniqueItems, pattern) // Remove to avoid duplicates in the final list + } + } + + joined := strings.Join(finalList, "\n") + err = os.WriteFile(o.Patterns.UniquePatternsFilePath, []byte(joined), 0o644) + return +}