// Source: fabric/vendors/gemini/gemini.go (158 lines, 3.7 KiB, Go)

package gemini
import (
"context"
"errors"
"fmt"
"strings"
2024-08-16 19:43:27 +00:00
"github.com/danielmiessler/fabric/common"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
const modelsNamePrefix = "models/"
2024-08-16 19:43:27 +00:00
// NewClient builds a Gemini vendor client.
//
// The client's embedded Configurable is labelled "Gemini" and gets its
// environment-variable prefix from common.BuildEnvVariablePrefix. A single
// "API key" setup question is registered on it and stored in ApiKey.
func NewClient() (ret *Client) {
	const vendorName = "Gemini"

	cfg := &common.Configurable{
		Label:         vendorName,
		EnvNamePrefix: common.BuildEnvVariablePrefix(vendorName),
	}

	ret = &Client{Configurable: cfg}
	// NOTE(review): the second argument is presumably a "required" flag — confirm
	// against common.Configurable.AddSetupQuestion.
	ret.ApiKey = cfg.AddSetupQuestion("API key", true)
	return ret
}
// Client is the Gemini vendor implementation. It embeds the shared
// common.Configurable and keeps a handle to the API key setup question whose
// Value is passed to the genai client on every call.
type Client struct {
// Configurable carries the vendor label and environment-variable prefix set
// up by NewClient.
*common.Configurable
// ApiKey is the setup question registered by NewClient; its Value is used as
// the Gemini API key.
ApiKey *common.SetupQuestion
}
func (o *Client) ListModels() (ret []string, err error) {
2024-08-16 19:43:27 +00:00
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
2024-08-16 19:43:27 +00:00
return
}
defer client.Close()
iter := client.ListModels(ctx)
for {
var resp *genai.ModelInfo
if resp, err = iter.Next(); err != nil {
if errors.Is(err, iterator.Done) {
err = nil
}
2024-08-16 19:43:27 +00:00
break
}
name := o.buildModelNameSimple(resp.Name)
ret = append(ret, name)
2024-08-16 19:43:27 +00:00
}
return
}
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs)
2024-08-16 19:43:27 +00:00
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
2024-08-16 19:43:27 +00:00
return
}
defer client.Close()
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
2024-08-16 19:43:27 +00:00
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
var response *genai.GenerateContentResponse
if response, err = model.GenerateContent(ctx, messages...); err != nil {
2024-08-16 19:43:27 +00:00
return
}
ret = o.extractText(response)
2024-08-16 19:43:27 +00:00
return
}
// buildModelNameSimple returns fullModelName without a leading
// modelsNamePrefix. Names that do not start with the prefix are returned
// unchanged.
func (o *Client) buildModelNameSimple(fullModelName string) string {
	if strings.HasPrefix(fullModelName, modelsNamePrefix) {
		return fullModelName[len(modelsNamePrefix):]
	}
	return fullModelName
}
// buildModelNameFull returns modelName qualified with modelsNamePrefix.
//
// A name that already starts with the prefix is returned as is, so passing an
// already-qualified name does not yield a doubled prefix such as
// "models/models/...".
func (o *Client) buildModelNameFull(modelName string) string {
	if strings.HasPrefix(modelName, modelsNamePrefix) {
		return modelName
	}
	return modelsNamePrefix + modelName
}
func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
2024-08-16 19:43:27 +00:00
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
2024-08-16 19:43:27 +00:00
return
}
defer client.Close()
systemInstruction, messages := toMessages(msgs)
2024-08-16 19:43:27 +00:00
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
2024-08-16 19:43:27 +00:00
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
iter := model.GenerateContentStream(ctx, messages...)
2024-08-16 19:43:27 +00:00
for {
if resp, iterErr := iter.Next(); iterErr == nil {
2024-08-16 19:43:27 +00:00
for _, candidate := range resp.Candidates {
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
if text, ok := part.(genai.Text); ok {
channel <- string(text)
}
}
}
}
} else {
if !errors.Is(iterErr, iterator.Done) {
channel <- fmt.Sprintf("%v\n", iterErr)
}
2024-08-16 19:43:27 +00:00
close(channel)
break
2024-08-16 19:43:27 +00:00
}
}
return
2024-08-16 19:43:27 +00:00
}
func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
2024-08-16 19:43:27 +00:00
for _, candidate := range response.Candidates {
if candidate.Content == nil {
break
}
for _, part := range candidate.Content.Parts {
if text, ok := part.(genai.Text); ok {
ret += string(text)
}
}
}
return
}
func toMessages(msgs []*common.Message) (systemInstruction *genai.Content, messages []genai.Part) {
if len(msgs) >= 2 {
systemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(msgs[0].Content),
},
}
for _, msg := range msgs[1:] {
messages = append(messages, genai.Text(msg.Content))
}
} else {
messages = append(messages, genai.Text(msgs[0].Content))
2024-08-16 19:43:27 +00:00
}
return
}