feat: improve Gemini model name handling

This commit is contained in:
Eugen Eisler 2024-08-17 00:59:34 +02:00
parent 8bf32b1894
commit 92e32b926d

View File

@ -3,6 +3,8 @@ package gemini
import (
"context"
"errors"
"fmt"
"strings"
"github.com/danielmiessler/fabric/common"
"github.com/google/generative-ai-go/genai"
@ -10,6 +12,8 @@ import (
"google.golang.org/api/option"
)
const modelsNamePrefix = "models/"
func NewClient() (ret *Client) {
vendorName := "Gemini"
ret = &Client{}
@ -29,10 +33,10 @@ type Client struct {
ApiKey *common.SetupQuestion
}
func (ge *Client) ListModels() (ret []string, err error) {
func (o *Client) ListModels() (ret []string, err error) {
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return
}
defer client.Close()
@ -46,22 +50,24 @@ func (ge *Client) ListModels() (ret []string, err error) {
}
break
}
ret = append(ret, resp.Name)
name := o.buildModelNameSimple(resp.Name)
ret = append(ret, name)
}
return
}
func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, userText := toContent(msgs)
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return
}
defer client.Close()
model := client.GenerativeModel(opts.Model)
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
@ -71,21 +77,29 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
return
}
ret = ge.extractText(response)
ret = o.extractText(response)
return
}
func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
func (o *Client) buildModelNameSimple(fullModelName string) string {
return strings.TrimPrefix(fullModelName, modelsNamePrefix)
}
func (o *Client) buildModelNameFull(modelName string) string {
return fmt.Sprintf("%v%v", modelsNamePrefix, modelName)
}
func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return
}
defer client.Close()
systemInstruction, userText := toContent(msgs)
model := client.GenerativeModel(opts.Model)
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
@ -112,7 +126,7 @@ func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, c
}
}
func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
for _, candidate := range response.Candidates {
if candidate.Content == nil {
break