mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-10 07:10:31 +00:00
130 lines
3.0 KiB
Go
130 lines
3.0 KiB
Go
|
package openai
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
|
||
|
"github.com/danielmiessler/fabric/common"
|
||
|
"github.com/samber/lo"
|
||
|
"github.com/sashabaranov/go-openai"
|
||
|
goopenai "github.com/sashabaranov/go-openai"
|
||
|
)
|
||
|
|
||
|
func NewClient() (ret *Client) {
|
||
|
return NewClientCompatible("OpenAI", nil)
|
||
|
}
|
||
|
|
||
|
func NewClientCompatible(vendorName string, configureCustom func() error) (ret *Client) {
|
||
|
ret = &Client{}
|
||
|
|
||
|
if configureCustom == nil {
|
||
|
configureCustom = ret.configure
|
||
|
}
|
||
|
|
||
|
ret.Configurable = &common.Configurable{
|
||
|
Label: vendorName,
|
||
|
EnvNamePrefix: common.BuildEnvVariablePrefix(vendorName),
|
||
|
ConfigureCustom: configureCustom,
|
||
|
}
|
||
|
|
||
|
ret.ApiKey = ret.AddSetupQuestion("API key", true)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type Client struct {
|
||
|
*common.Configurable
|
||
|
ApiKey *common.SetupQuestion
|
||
|
ApiClient *openai.Client
|
||
|
}
|
||
|
|
||
|
func (o *Client) configure() (ret error) {
|
||
|
o.ApiClient = openai.NewClient(o.ApiKey.Value)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (o *Client) ListModels() (ret []string, err error) {
|
||
|
var models openai.ModelsList
|
||
|
if models, err = o.ApiClient.ListModels(context.Background()); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
model := models.Models
|
||
|
for _, mod := range model {
|
||
|
ret = append(ret, mod.ID)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (o *Client) SendStream(
|
||
|
msgs []*common.Message, opts *common.ChatOptions, channel chan string,
|
||
|
) (err error) {
|
||
|
req := o.buildChatCompletionRequest(msgs, opts)
|
||
|
req.Stream = true
|
||
|
|
||
|
var stream *openai.ChatCompletionStream
|
||
|
if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil {
|
||
|
fmt.Printf("ChatCompletionStream error: %v\n", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
defer stream.Close()
|
||
|
|
||
|
for {
|
||
|
var response openai.ChatCompletionStreamResponse
|
||
|
if response, err = stream.Recv(); err == nil {
|
||
|
channel <- response.Choices[0].Delta.Content
|
||
|
} else if errors.Is(err, io.EOF) {
|
||
|
channel <- "\n"
|
||
|
close(channel)
|
||
|
err = nil
|
||
|
break
|
||
|
} else if err != nil {
|
||
|
fmt.Printf("\nStream error: %v\n", err)
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||
|
req := o.buildChatCompletionRequest(msgs, opts)
|
||
|
|
||
|
var resp goopenai.ChatCompletionResponse
|
||
|
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil {
|
||
|
return
|
||
|
}
|
||
|
ret = resp.Choices[0].Message.Content
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (o *Client) buildChatCompletionRequest(
|
||
|
msgs []*common.Message, opts *common.ChatOptions,
|
||
|
) (ret goopenai.ChatCompletionRequest) {
|
||
|
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
|
||
|
var role string
|
||
|
|
||
|
switch message.Role {
|
||
|
case "user":
|
||
|
role = goopenai.ChatMessageRoleUser
|
||
|
case "system":
|
||
|
role = goopenai.ChatMessageRoleSystem
|
||
|
default:
|
||
|
role = goopenai.ChatMessageRoleSystem
|
||
|
}
|
||
|
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
|
||
|
})
|
||
|
|
||
|
ret = goopenai.ChatCompletionRequest{
|
||
|
Model: opts.Model,
|
||
|
Temperature: float32(opts.Temperature),
|
||
|
TopP: float32(opts.TopP),
|
||
|
PresencePenalty: float32(opts.PresencePenalty),
|
||
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||
|
Messages: messages,
|
||
|
}
|
||
|
return
|
||
|
}
|