mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
refactor: accept context as parameter of Vendor.Send
In golang, contexts should be propagated downwards in order to be able to provide features such as cancellation. This commit refactors the Vendor interface to accept a context as a first parameter so that it can be propagated downwards.
This commit is contained in:
parent
e8d5fba256
commit
21f4b5f774
@ -1,7 +1,9 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
"github.com/danielmiessler/fabric/vendors"
|
||||
@ -17,7 +19,6 @@ type Chatter struct {
|
||||
}
|
||||
|
||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||
|
||||
var chatRequest *Chat
|
||||
if chatRequest, err = o.NewChat(request); err != nil {
|
||||
return
|
||||
@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
fmt.Print(response)
|
||||
}
|
||||
} else {
|
||||
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
|
||||
if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
||||
}
|
||||
|
||||
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||
|
||||
ret = &Chat{}
|
||||
|
||||
if request.ContextName != "" {
|
||||
|
@ -2,8 +2,10 @@ package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
func TestNewVendorsManager(t *testing.T) {
|
||||
@ -90,17 +92,17 @@ type MockVendor struct {
|
||||
}
|
||||
|
||||
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
|
||||
//TODO implement me
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) {
|
||||
//TODO implement me
|
||||
func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
||||
//TODO implement me
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
|
3
vendors/anthropic/anthropic.go
vendored
3
vendors/anthropic/anthropic.go
vendored
@ -79,8 +79,7 @@ func (an *Client) SendStream(
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
ctx := context.Background()
|
||||
func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := an.buildMessagesRequest(msgs, opts)
|
||||
req.Stream = false
|
||||
|
||||
|
3
vendors/gemini/gemini.go
vendored
3
vendors/gemini/gemini.go
vendored
@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
systemInstruction, messages := toMessages(msgs)
|
||||
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
return
|
||||
|
4
vendors/ollama/ollama.go
vendored
4
vendors/ollama/ollama.go
vendored
@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
bf := false
|
||||
|
||||
req := o.createChatRequest(msgs, opts)
|
||||
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
|
||||
fmt.Printf("FRED --> %s\n", err)
|
||||
}
|
||||
|
4
vendors/openai/openai.go
vendored
4
vendors/openai/openai.go
vendored
@ -96,11 +96,11 @@ func (o *Client) SendStream(
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := o.buildChatCompletionRequest(msgs, opts)
|
||||
|
||||
var resp goopenai.ChatCompletionResponse
|
||||
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil {
|
||||
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
ret = resp.Choices[0].Message.Content
|
||||
|
4
vendors/vendor.go
vendored
4
vendors/vendor.go
vendored
@ -2,6 +2,8 @@ package vendors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
@ -11,7 +13,7 @@ type Vendor interface {
|
||||
Configure() error
|
||||
ListModels() ([]string, error)
|
||||
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
||||
Send([]*common.Message, *common.ChatOptions) (string, error)
|
||||
Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
|
||||
Setup() error
|
||||
SetupFillEnvFileContent(*bytes.Buffer)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user