mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-08 07:11:06 +00:00
feat: improve Gemini model name handling
This commit is contained in:
parent
8bf32b1894
commit
92e32b926d
36
vendors/gemini/gemini.go
vendored
36
vendors/gemini/gemini.go
vendored
@ -3,6 +3,8 @@ package gemini
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
@ -10,6 +12,8 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const modelsNamePrefix = "models/"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Gemini"
|
||||
ret = &Client{}
|
||||
@ -29,10 +33,10 @@ type Client struct {
|
||||
ApiKey *common.SetupQuestion
|
||||
}
|
||||
|
||||
func (ge *Client) ListModels() (ret []string, err error) {
|
||||
func (o *Client) ListModels() (ret []string, err error) {
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
@ -46,22 +50,24 @@ func (ge *Client) ListModels() (ret []string, err error) {
|
||||
}
|
||||
break
|
||||
}
|
||||
ret = append(ret, resp.Name)
|
||||
|
||||
name := o.buildModelNameSimple(resp.Name)
|
||||
ret = append(ret, name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||
systemInstruction, userText := toContent(msgs)
|
||||
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
model := client.GenerativeModel(opts.Model)
|
||||
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
|
||||
model.SetTemperature(float32(opts.Temperature))
|
||||
model.SetTopP(float32(opts.TopP))
|
||||
model.SystemInstruction = systemInstruction
|
||||
@ -71,21 +77,29 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
|
||||
return
|
||||
}
|
||||
|
||||
ret = ge.extractText(response)
|
||||
ret = o.extractText(response)
|
||||
return
|
||||
}
|
||||
|
||||
func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (o *Client) buildModelNameSimple(fullModelName string) string {
|
||||
return strings.TrimPrefix(fullModelName, modelsNamePrefix)
|
||||
}
|
||||
|
||||
func (o *Client) buildModelNameFull(modelName string) string {
|
||||
return fmt.Sprintf("%v%v", modelsNamePrefix, modelName)
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil {
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
systemInstruction, userText := toContent(msgs)
|
||||
|
||||
model := client.GenerativeModel(opts.Model)
|
||||
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
|
||||
model.SetTemperature(float32(opts.Temperature))
|
||||
model.SetTopP(float32(opts.TopP))
|
||||
model.SystemInstruction = systemInstruction
|
||||
@ -112,7 +126,7 @@ func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, c
|
||||
}
|
||||
}
|
||||
|
||||
func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
|
||||
func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
|
||||
for _, candidate := range response.Candidates {
|
||||
if candidate.Content == nil {
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user