mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
test: implement test for common package
This commit is contained in:
parent
69375f2fbc
commit
4d77ed30e9
176
common/configurable_test.go
Normal file
176
common/configurable_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigurable_AddSetting(t *testing.T) {
|
||||
conf := &Configurable{
|
||||
Settings: Settings{},
|
||||
Label: "TestConfigurable",
|
||||
EnvNamePrefix: "TEST_",
|
||||
}
|
||||
|
||||
setting := conf.AddSetting("test_setting", true)
|
||||
assert.Equal(t, "TEST_test_setting", setting.EnvVariable)
|
||||
assert.True(t, setting.Required)
|
||||
assert.Contains(t, conf.Settings, setting)
|
||||
}
|
||||
|
||||
func TestConfigurable_Configure(t *testing.T) {
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Required: true,
|
||||
}
|
||||
conf := &Configurable{
|
||||
Settings: Settings{setting},
|
||||
Label: "TestConfigurable",
|
||||
}
|
||||
|
||||
os.Setenv("TEST_SETTING", "test_value")
|
||||
err := conf.Configure()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_value", setting.Value)
|
||||
}
|
||||
|
||||
func TestConfigurable_Setup(t *testing.T) {
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Required: false,
|
||||
}
|
||||
conf := &Configurable{
|
||||
Settings: Settings{setting},
|
||||
Label: "TestConfigurable",
|
||||
}
|
||||
|
||||
err := conf.Setup()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSetting_IsValid(t *testing.T) {
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Value: "some_value",
|
||||
Required: true,
|
||||
}
|
||||
|
||||
assert.True(t, setting.IsValid())
|
||||
|
||||
setting.Value = ""
|
||||
assert.False(t, setting.IsValid())
|
||||
}
|
||||
|
||||
func TestSetting_Configure(t *testing.T) {
|
||||
os.Setenv("TEST_SETTING", "test_value")
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Required: true,
|
||||
}
|
||||
err := setting.Configure()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_value", setting.Value)
|
||||
}
|
||||
|
||||
func TestSetting_FillEnvFileContent(t *testing.T) {
|
||||
buffer := &bytes.Buffer{}
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Value: "test_value",
|
||||
}
|
||||
setting.FillEnvFileContent(buffer)
|
||||
|
||||
expected := "TEST_SETTING=test_value\n"
|
||||
assert.Equal(t, expected, buffer.String())
|
||||
}
|
||||
|
||||
func TestSetting_Print(t *testing.T) {
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Value: "test_value",
|
||||
}
|
||||
expected := "TEST_SETTING: test_value\n"
|
||||
fmtOutput := captureOutput(func() {
|
||||
setting.Print()
|
||||
})
|
||||
assert.Equal(t, expected, fmtOutput)
|
||||
}
|
||||
|
||||
func TestSetupQuestion_Ask(t *testing.T) {
|
||||
setting := &Setting{
|
||||
EnvVariable: "TEST_SETTING",
|
||||
Required: true,
|
||||
}
|
||||
question := &SetupQuestion{
|
||||
Setting: setting,
|
||||
Question: "Enter test setting:",
|
||||
}
|
||||
input := "user_value\n"
|
||||
fmtInput := captureInput(input)
|
||||
defer fmtInput()
|
||||
err := question.Ask("TestConfigurable")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "user_value", setting.Value)
|
||||
}
|
||||
|
||||
func TestSettings_IsConfigured(t *testing.T) {
|
||||
settings := Settings{
|
||||
{EnvVariable: "TEST_SETTING1", Value: "value1", Required: true},
|
||||
{EnvVariable: "TEST_SETTING2", Value: "", Required: false},
|
||||
}
|
||||
|
||||
assert.True(t, settings.IsConfigured())
|
||||
|
||||
settings[0].Value = ""
|
||||
assert.False(t, settings.IsConfigured())
|
||||
}
|
||||
|
||||
func TestSettings_Configure(t *testing.T) {
|
||||
os.Setenv("TEST_SETTING", "test_value")
|
||||
settings := Settings{
|
||||
{EnvVariable: "TEST_SETTING", Required: true},
|
||||
}
|
||||
|
||||
err := settings.Configure()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_value", settings[0].Value)
|
||||
}
|
||||
|
||||
func TestSettings_FillEnvFileContent(t *testing.T) {
|
||||
buffer := &bytes.Buffer{}
|
||||
settings := Settings{
|
||||
{EnvVariable: "TEST_SETTING", Value: "test_value"},
|
||||
}
|
||||
settings.FillEnvFileContent(buffer)
|
||||
|
||||
expected := "TEST_SETTING=test_value\n"
|
||||
assert.Equal(t, expected, buffer.String())
|
||||
}
|
||||
|
||||
// captureOutput captures the output of a function call
|
||||
func captureOutput(f func()) string {
|
||||
var buf bytes.Buffer
|
||||
stdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
f()
|
||||
_ = w.Close()
|
||||
os.Stdout = stdout
|
||||
buf.ReadFrom(r)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// captureInput captures the input for a function call
|
||||
func captureInput(input string) func() {
|
||||
r, w, _ := os.Pipe()
|
||||
_, _ = w.WriteString(input)
|
||||
w.Close()
|
||||
stdin := os.Stdin
|
||||
os.Stdin = r
|
||||
return func() {
|
||||
os.Stdin = stdin
|
||||
}
|
||||
}
|
@ -19,3 +19,24 @@ type ChatOptions struct {
|
||||
PresencePenalty float64
|
||||
FrequencyPenalty float64
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) {
|
||||
// Iterate over messages to enforce the odd position rule for user messages
|
||||
fullMessageIndex := 0
|
||||
for _, message := range msgs {
|
||||
if message.Content == "" {
|
||||
// Skip empty messages as the anthropic API doesn't accept them
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure, that each odd position shall be a user message
|
||||
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
||||
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
||||
fullMessageIndex++
|
||||
}
|
||||
ret = append(ret, message)
|
||||
fullMessageIndex++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
25
common/domain_test.go
Normal file
25
common/domain_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeMessages(t *testing.T) {
|
||||
msgs := []*Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "bot", Content: "Hi there!"},
|
||||
{Role: "bot", Content: ""},
|
||||
{Role: "user", Content: ""},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
}
|
||||
|
||||
expected := []*Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "bot", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
}
|
||||
|
||||
actual := NormalizeMessages(msgs, "default")
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package common
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) {
|
||||
// Iterate over messages to enforce the odd position rule for user messages
|
||||
fullMessageIndex := 0
|
||||
for _, message := range msgs {
|
||||
if message.Content == "" {
|
||||
// Skip empty messages as the anthropic API doesn't accept them
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure, that each odd position shall be a user message
|
||||
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
||||
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
||||
fullMessageIndex++
|
||||
}
|
||||
ret = append(ret, message)
|
||||
fullMessageIndex++
|
||||
}
|
||||
return
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
package common
|
||||
|
||||
type Vendor interface {
|
||||
GetName() string
|
||||
IsConfigured() bool
|
||||
Configure() error
|
||||
ListModels() ([]string, error)
|
||||
SendStream([]*Message, *ChatOptions, chan string) error
|
||||
Send([]*Message, *ChatOptions) (string, error)
|
||||
GetSettings() Settings
|
||||
Setup() error
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
"github.com/danielmiessler/fabric/vendors"
|
||||
)
|
||||
|
||||
type Chatter struct {
|
||||
@ -12,7 +13,7 @@ type Chatter struct {
|
||||
Stream bool
|
||||
|
||||
model string
|
||||
vendor common.Vendor
|
||||
vendor vendors.Vendor
|
||||
}
|
||||
|
||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||
|
@ -3,29 +3,29 @@ package core
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/vendors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func NewVendorsManager() *VendorsManager {
|
||||
return &VendorsManager{
|
||||
Vendors: map[string]common.Vendor{},
|
||||
Vendors: map[string]vendors.Vendor{},
|
||||
}
|
||||
}
|
||||
|
||||
type VendorsManager struct {
|
||||
Vendors map[string]common.Vendor
|
||||
Vendors map[string]vendors.Vendor
|
||||
Models *VendorsModels
|
||||
}
|
||||
|
||||
func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
|
||||
func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) {
|
||||
for _, vendor := range vendors {
|
||||
o.Vendors[vendor.GetName()] = vendor
|
||||
}
|
||||
}
|
||||
|
||||
func (o *VendorsManager) Reset() {
|
||||
o.Vendors = map[string]common.Vendor{}
|
||||
o.Vendors = map[string]vendors.Vendor{}
|
||||
o.Models = nil
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ func (o *VendorsManager) HasVendors() bool {
|
||||
return len(o.Vendors) > 0
|
||||
}
|
||||
|
||||
func (o *VendorsManager) FindByName(name string) common.Vendor {
|
||||
func (o *VendorsManager) FindByName(name string) vendors.Vendor {
|
||||
return o.Vendors[name]
|
||||
}
|
||||
|
||||
@ -76,7 +76,7 @@ func (o *VendorsManager) readModels() {
|
||||
}
|
||||
|
||||
func (o *VendorsManager) fetchVendorModels(
|
||||
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
|
||||
ctx context.Context, wg *sync.WaitGroup, vendor vendors.Vendor, resultsChan chan<- modelResult) {
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
|
4
go.mod
4
go.mod
@ -16,6 +16,7 @@ require (
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/samber/lo v1.47.0
|
||||
github.com/sashabaranov/go-openai v1.28.2
|
||||
github.com/stretchr/testify v1.9.0
|
||||
google.golang.org/api v0.192.0
|
||||
gopkg.in/gookit/color.v1 v1.1.6
|
||||
)
|
||||
@ -32,6 +33,7 @@ require (
|
||||
github.com/ProtonMail/go-crypto v1.0.0 // indirect
|
||||
github.com/cloudflare/circl v1.3.7 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||
@ -46,6 +48,7 @@ require (
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
github.com/skeema/knownhosts v1.2.2 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
@ -69,4 +72,5 @@ require (
|
||||
google.golang.org/grpc v1.64.1 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
28
utils/log.go
28
utils/log.go
@ -1,28 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/gookit/color.v1"
|
||||
)
|
||||
|
||||
func Print(info string) {
|
||||
fmt.Println(info)
|
||||
}
|
||||
|
||||
func PrintWarning (s string) {
|
||||
fmt.Println(color.Yellow.Render("Warning: " + s))
|
||||
}
|
||||
|
||||
func LogError(err error) {
|
||||
fmt.Fprintln(os.Stderr, color.Red.Render(err.Error()))
|
||||
}
|
||||
|
||||
func LogWarning(err error) {
|
||||
fmt.Fprintln(os.Stderr, color.Yellow.Render(err.Error()))
|
||||
}
|
||||
|
||||
func Log(info string) {
|
||||
fmt.Println(color.Green.Render(info))
|
||||
}
|
14
vendors/vendor.go
vendored
Normal file
14
vendors/vendor.go
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
package vendors
|
||||
|
||||
import "github.com/danielmiessler/fabric/common"
|
||||
|
||||
type Vendor interface {
|
||||
GetName() string
|
||||
IsConfigured() bool
|
||||
Configure() error
|
||||
ListModels() ([]string, error)
|
||||
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
||||
Send([]*common.Message, *common.ChatOptions) (string, error)
|
||||
GetSettings() common.Settings
|
||||
Setup() error
|
||||
}
|
Loading…
Reference in New Issue
Block a user