refactor: accept context as parameter of Vendor.Send

In golang, contexts should be propagated downwards in order to be able
to provide features such as cancellation.

This commit refactors the Vendor interface to accept a context as a
first parameter so that it can be propagated downwards.
This commit is contained in:
ALX99 2024-08-26 19:34:15 +09:00
parent e8d5fba256
commit 21f4b5f774
7 changed files with 18 additions and 18 deletions

View File

@ -1,7 +1,9 @@
package core
import (
"context"
"fmt"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors"
@ -17,7 +19,6 @@ type Chatter struct {
}
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil {
return
@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
fmt.Print(response)
}
} else {
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
return
}
}
@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
}
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
ret = &Chat{}
if request.ContextName != "" {

View File

@ -2,8 +2,10 @@ package core
import (
"bytes"
"github.com/danielmiessler/fabric/common"
"context"
"testing"
"github.com/danielmiessler/fabric/common"
)
func TestNewVendorsManager(t *testing.T) {
@ -90,17 +92,17 @@ type MockVendor struct {
}
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
//TODO implement me
// TODO implement me
panic("implement me")
}
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) {
//TODO implement me
func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
// TODO implement me
panic("implement me")
}
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
//TODO implement me
// TODO implement me
panic("implement me")
}

View File

@ -79,8 +79,7 @@ func (an *Client) SendStream(
return
}
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
ctx := context.Background()
func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
req := an.buildMessagesRequest(msgs, opts)
req.Stream = false

View File

@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
return
}
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs)
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return

View File

@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return
}
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
bf := false
req := o.createChatRequest(msgs, opts)
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
return
}
ctx := context.Background()
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
fmt.Printf("FRED --> %s\n", err)
}

View File

@ -96,11 +96,11 @@ func (o *Client) SendStream(
return
}
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
req := o.buildChatCompletionRequest(msgs, opts)
var resp goopenai.ChatCompletionResponse
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil {
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
return
}
ret = resp.Choices[0].Message.Content

4
vendors/vendor.go vendored
View File

@ -2,6 +2,8 @@ package vendors
import (
"bytes"
"context"
"github.com/danielmiessler/fabric/common"
)
@ -11,7 +13,7 @@ type Vendor interface {
Configure() error
ListModels() ([]string, error)
SendStream([]*common.Message, *common.ChatOptions, chan string) error
Send([]*common.Message, *common.ChatOptions) (string, error)
Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
Setup() error
SetupFillEnvFileContent(*bytes.Buffer)
}