feat: allow patching req body with client config (#534)

Authored by sigoden on 2024-05-22 21:29:23 +08:00, committed by GitHub
parent 91a06543b2
commit ba3bcfd67c
20 changed files with 194 additions and 115 deletions

Cargo.lock (generated)

@ -54,6 +54,7 @@ dependencies = [
"indexmap",
"inquire",
"is-terminal",
"json-patch",
"lazy_static",
"log",
"mime_guess",
@ -715,6 +716,15 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "fluent-uri"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -1139,6 +1149,29 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "json-patch"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b1fb8864823fad91877e6caea0baca82e49e8db50f8e5c9f9a453e27d3330fc"
dependencies = [
"jsonptr",
"serde",
"serde_json",
"thiserror",
]
[[package]]
name = "jsonptr"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c6e529149475ca0b2820835d3dce8fcc41c6b943ca608d32f35b449255e4627"
dependencies = [
"fluent-uri",
"serde",
"serde_json",
]
[[package]]
name = "lazy_static"
version = "1.4.0"

@ -58,6 +58,7 @@ urlencoding = "2.1.3"
unicode-segmentation = "1.11.0"
num_cpus = "1.16.0"
threadpool = "1.8.1"
json-patch = { version = "2.0.0", default-features = false }
[dependencies.reqwest]
version = "0.12.0"

@ -37,8 +37,9 @@ clients:
#   - name: xxxx # The model name
#     max_input_tokens: 100000
#     supports_vision: true
#     extra_fields: # Set custom parameters, will merge with the body json
#       key: value
# patches:
#   <regex>: # The regex to match model names, e.g. '.*' 'gpt-4o' 'gpt-4o|gpt-4-.*'
#     request_body: # The JSON to be merged with the request body.
# extra:
#   proxy: socks5://127.0.0.1:1080 # Set https/socks5 proxy. ENV: HTTPS_PROXY/https_proxy/ALL_PROXY/all_proxy
#   connect_timeout: 10 # Set timeout in seconds for connect to api
@ -62,15 +63,18 @@ clients:
# See https://ai.google.dev/docs
- type: gemini
  api_key: xxx # ENV: {client}_API_KEY
  safetySettings:
    - category: HARM_CATEGORY_HARASSMENT
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_HATE_SPEECH
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_SEXUALLY_EXPLICIT
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_DANGEROUS_CONTENT
      threshold: BLOCK_NONE
  patches:
    '.*':
      request_body: # Override safetySettings for all models
        safetySettings:
          - category: HARM_CATEGORY_HARASSMENT
            threshold: BLOCK_NONE
          - category: HARM_CATEGORY_HATE_SPEECH
            threshold: BLOCK_NONE
          - category: HARM_CATEGORY_SEXUALLY_EXPLICIT
            threshold: BLOCK_NONE
          - category: HARM_CATEGORY_DANGEROUS_CONTENT
            threshold: BLOCK_NONE

# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
@ -123,15 +127,18 @@ clients:
  # Run `gcloud auth application-default login` to init the adc file
  # see https://cloud.google.com/docs/authentication/external/set-up-adc
  adc_file: <path-to/gcloud/application_default_credentials.json>
  safetySettings:
    - category: HARM_CATEGORY_HARASSMENT
      threshold: BLOCK_ONLY_HIGH
    - category: HARM_CATEGORY_HATE_SPEECH
      threshold: BLOCK_ONLY_HIGH
    - category: HARM_CATEGORY_SEXUALLY_EXPLICIT
      threshold: BLOCK_ONLY_HIGH
    - category: HARM_CATEGORY_DANGEROUS_CONTENT
      threshold: BLOCK_ONLY_HIGH
  patches:
    'gemini-.*':
      request_body: # Override safetySettings for all gemini models
        safetySettings:
          - category: HARM_CATEGORY_HARASSMENT
            threshold: BLOCK_ONLY_HIGH
          - category: HARM_CATEGORY_HATE_SPEECH
            threshold: BLOCK_ONLY_HIGH
          - category: HARM_CATEGORY_SEXUALLY_EXPLICIT
            threshold: BLOCK_ONLY_HIGH
          - category: HARM_CATEGORY_DANGEROUS_CONTENT
            threshold: BLOCK_ONLY_HIGH

# See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
- type: vertexai-claude
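
A rough sketch (not part of the commit) of how a `patches` section like the one above maps onto the new `ModelPatches`/`ModelPatch` types introduced later in this diff. The stand-in struct and the model-name regexes are illustrative; the real map type is an order-preserving `IndexMap`, and the real struct lives in the client module.

```rust
use serde::Deserialize;
use serde_json::Value;
use std::collections::BTreeMap;

// Stand-in for the crate's ModelPatch; in this commit it only carries request_body.
#[derive(Debug, Deserialize)]
struct ModelPatch {
    #[serde(default)]
    request_body: Value,
}

fn main() -> anyhow::Result<()> {
    // Illustrative `patches` fragment; each key is a regex over model names.
    let yaml = r#"
'gpt-4o|gpt-4-.*':
  request_body:
    max_tokens: 4096
'gemini-.*':
  request_body:
    safetySettings:
      - category: HARM_CATEGORY_HARASSMENT
        threshold: BLOCK_NONE
"#;
    // A BTreeMap is enough for the sketch; ordering matters in the real IndexMap
    // because the first matching pattern wins.
    let patches: BTreeMap<String, ModelPatch> = serde_yaml::from_str(yaml)?;
    for (pattern, patch) in &patches {
        println!("{pattern} => {}", patch.request_body);
    }
    Ok(())
}
```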

@ -1,5 +1,7 @@
use super::openai::openai_build_body;
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData};
use super::{
openai::*, AzureOpenAIClient, Client, ExtraConfig, Model, ModelData, ModelPatches,
PromptAction, PromptKind, SendData,
};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@ -11,6 +13,7 @@ pub struct AzureOpenAIConfig {
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -35,7 +38,7 @@ impl AzureOpenAIClient {
let api_key = self.get_api_key()?;
let mut body = openai_build_body(data, &self.model);
self.model.merge_extra_fields(&mut body);
self.patch_request_body(&mut body);
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",

@ -1,8 +1,7 @@
use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionOutput, ExtraConfig, Model,
ModelData, PromptAction, PromptFormat, PromptKind, SendData, SseHandler, LLAMA3_PROMPT_FORMAT,
MISTRAL_PROMPT_FORMAT,
prompt_format::*, claude::*,
catch_error, BedrockClient, Client, CompletionOutput, ExtraConfig, Model, ModelData,
ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
@ -31,6 +30,7 @@ pub struct BedrockConfig {
pub region: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -102,7 +102,7 @@ impl BedrockClient {
let headers = IndexMap::new();
let mut body = build_body(data, &self.model, model_category)?;
self.model.merge_extra_fields(&mut body);
self.patch_request_body(&mut body);
let builder = aws_fetch(
client,

@ -1,7 +1,7 @@
use super::{
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, CompletionOutput,
ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction,
PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client,
CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData,
ModelPatches, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{bail, Context, Result};
@ -17,6 +17,7 @@ pub struct ClaudeConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -29,7 +30,8 @@ impl ClaudeClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let body = claude_build_body(data, &self.model)?;
let mut body = claude_build_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = API_BASE;

@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CloudflareClient, CompletionOutput, ExtraConfig, Model, ModelData,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
catch_error, sse_stream, Client, CloudflareClient, CompletionOutput, ExtraConfig, Model,
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
use anyhow::{anyhow, Result};
@ -17,6 +17,7 @@ pub struct CloudflareConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -33,7 +34,8 @@ impl CloudflareClient {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;
let body = build_body(data, &self.model)?;
let mut body = build_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",

@ -1,6 +1,7 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SseHandler, ToolCall,
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
SendData, SseHandler, ToolCall,
};
use anyhow::{bail, Result};
@ -16,6 +17,7 @@ pub struct CohereConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -28,7 +30,8 @@ impl CohereClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
let body = build_body(data, &self.model)?;
let mut body = build_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = API_URL;

@ -9,7 +9,9 @@ use crate::{
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use fancy_regex::Regex;
use futures_util::{Stream, StreamExt};
use indexmap::IndexMap;
use lazy_static::lazy_static;
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
@ -23,6 +25,7 @@ const MODELS_YAML: &str = include_str!("../../models.yaml");
lazy_static! {
pub static ref ALL_CLIENT_MODELS: Vec<BuiltinModels> =
serde_yaml::from_str(MODELS_YAML).unwrap();
static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(?<!\\)/").unwrap();
}
#[macro_export]
@ -158,13 +161,16 @@ macro_rules! register_client {
#[macro_export]
macro_rules! client_common_fns {
() => {
fn config(
&self,
) -> (
&$crate::config::GlobalConfig,
&Option<$crate::client::ExtraConfig>,
) {
(&self.global_config, &self.config.extra)
fn global_config(&self) -> &$crate::config::GlobalConfig {
&self.global_config
}
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
self.config.extra.as_ref()
}
fn patches_config(&self) -> Option<&$crate::client::ModelPatches> {
self.config.patches.as_ref()
}
fn list_models(&self) -> Vec<Model> {
@ -246,8 +252,13 @@ macro_rules! unsupported_model {
#[async_trait]
pub trait Client: Sync + Send {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
fn global_config(&self) -> &GlobalConfig;
fn extra_config(&self) -> Option<&ExtraConfig>;
fn patches_config(&self) -> Option<&ModelPatches>;
#[allow(unused)]
fn name(&self) -> &str;
#[allow(unused)]
@ -262,12 +273,9 @@ pub trait Client: Sync + Send {
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let options = self.config().1;
let timeout = options
.as_ref()
.and_then(|v| v.connect_timeout)
.unwrap_or(10);
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
let extra = self.extra_config();
let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
let proxy = extra.and_then(|v| v.proxy.clone());
builder = set_proxy(builder, &proxy)?;
let client = builder
.connect_timeout(Duration::from_secs(timeout))
@ -277,8 +285,7 @@ pub trait Client: Sync + Send {
}
async fn send_message(&self, input: Input) -> Result<CompletionOutput> {
let global_config = self.config().0;
if global_config.read().dry_run {
if self.global_config().read().dry_run {
let content = input.echo_messages();
return Ok(CompletionOutput::new(&content));
}
@ -303,8 +310,7 @@ pub trait Client: Sync + Send {
let input = input.clone();
tokio::select! {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
if self.global_config().read().dry_run {
let content = input.echo_messages();
let tokens = tokenize(&content);
for token in tokens {
@ -327,6 +333,15 @@ pub trait Client: Sync + Send {
}
}
fn patch_request_body(&self, body: &mut Value) {
let model_name = self.model().name();
if let Some(patch_data) = slect_model_patch(self.patches_config(), model_name) {
if body.is_object() && patch_data.request_body.is_object() {
json_patch::merge(body, &patch_data.request_body)
}
}
}
async fn send_message_inner(
&self,
client: &ReqwestClient,
@ -353,6 +368,30 @@ pub struct ExtraConfig {
pub connect_timeout: Option<u64>,
}
pub type ModelPatches = IndexMap<String, ModelPatch>;
#[derive(Debug, Clone, Deserialize)]
pub struct ModelPatch {
#[serde(default)]
pub request_body: Value,
}
pub fn slect_model_patch<'a>(
patch: Option<&'a ModelPatches>,
name: &str,
) -> Option<&'a ModelPatch> {
let patch = patch?;
for (key, patch_data) in patch {
let key = ESCAPE_SLASH_RE.replace_all(key, r"\/");
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
if let Ok(true) = regex.is_match(name) {
return Some(patch_data);
}
}
}
None
}
#[derive(Debug)]
pub struct SendData {
pub messages: Vec<Message>,
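
To illustrate the selection logic above: each `patches` key is compiled as a regex anchored with `^(key)$` and matched against the current model name; the first match wins and its `request_body` is merged into the outgoing JSON body. The sketch below is not part of the diff, the model names and patch contents are illustrative, and it omits the slash-escaping step (`ESCAPE_SLASH_RE`) and the check that both sides are JSON objects.

```rust
use fancy_regex::Regex;
use indexmap::IndexMap;
use serde_json::{json, Value};

/// Return the first patch whose key, anchored as ^(key)$, matches the model name.
fn select_patch<'a>(patches: &'a IndexMap<String, Value>, model_name: &str) -> Option<&'a Value> {
    for (key, request_body) in patches {
        if let Ok(re) = Regex::new(&format!("^({key})$")) {
            if re.is_match(model_name).unwrap_or(false) {
                return Some(request_body);
            }
        }
    }
    None
}

fn main() {
    let mut patches = IndexMap::new();
    patches.insert("gpt-4o|gpt-4-.*".to_string(), json!({ "max_tokens": 4096 }));

    let mut body = json!({ "model": "gpt-4-turbo", "messages": [] });
    if let Some(request_body) = select_patch(&patches, "gpt-4-turbo") {
        // Same merge call the trait's patch_request_body uses.
        json_patch::merge(&mut body, request_body);
    }
    assert_eq!(body["max_tokens"], 4096);
}
```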

@ -1,7 +1,8 @@
use super::access_token::*;
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SendData, SsMmessage,
SseHandler,
};
use anyhow::{anyhow, Context, Result};
@ -21,6 +22,7 @@ pub struct ErnieConfig {
pub secret_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -31,7 +33,9 @@ impl ErnieClient {
];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let body = build_body(data, &self.model);
let mut body = build_body(data, &self.model);
self.patch_request_body(&mut body);
let access_token = get_access_token(self.name())?;
let url = format!(

@ -1,5 +1,7 @@
use super::vertexai::gemini_build_body;
use super::{ExtraConfig, GeminiClient, Model, ModelData, PromptAction, PromptKind, SendData};
use super::{
vertexai::*, Client, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SendData,
};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@ -15,6 +17,7 @@ pub struct GeminiConfig {
pub safety_settings: Option<serde_json::Value>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -32,7 +35,8 @@ impl GeminiClient {
false => "generateContent",
};
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
let mut body = gemini_build_body(data, &self.model)?;
self.patch_request_body(&mut body);
let model = &self.model.name();

@ -11,7 +11,6 @@ pub use crate::utils::PromptKind;
pub use common::*;
pub use message::*;
pub use model::*;
pub use prompt_format::*;
pub use sse_handler::*;
register_client!(

@ -191,26 +191,6 @@ impl Model {
}
Ok(())
}
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.data.extra_fields) {
for (key, extra_field) in extra_fields {
if body.contains_key(key) {
if let (Some(sub_body), Some(extra_field)) =
(body[key].as_object_mut(), extra_field.as_object())
{
for (subkey, sub_field) in extra_field {
if !sub_body.contains_key(subkey) {
sub_body.insert(subkey.clone(), sub_field.clone());
}
}
}
} else {
body.insert(key.clone(), extra_field.clone());
}
}
}
}
}
#[derive(Debug, Clone, Default, Deserialize)]
@ -226,7 +206,6 @@ pub struct ModelData {
pub supports_vision: bool,
#[serde(default)]
pub supports_function_calling: bool,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
}
impl ModelData {

@ -1,6 +1,6 @@
use super::{
catch_error, message::*, CompletionOutput, ExtraConfig, Model, ModelData, OllamaClient,
PromptAction, PromptKind, SendData, SseHandler,
catch_error, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches,
OllamaClient, PromptAction, PromptKind, SendData, SseHandler,
};
use anyhow::{anyhow, bail, Result};
@ -16,6 +16,7 @@ pub struct OllamaConfig {
pub api_auth: Option<String>,
pub chat_endpoint: Option<String>,
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -40,7 +41,7 @@ impl OllamaClient {
let api_auth = self.get_api_auth().ok();
let mut body = build_body(data, &self.model)?;
self.model.merge_extra_fields(&mut body);
self.patch_request_body(&mut body);
let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");

@ -1,6 +1,7 @@
use super::{
catch_error, message::*, sse_stream, CompletionOutput, ExtraConfig, Model, ModelData,
OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, ModelData,
ModelPatches, OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
ToolCall,
};
use anyhow::{bail, Result};
@ -18,6 +19,7 @@ pub struct OpenAIConfig {
pub organization_id: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -32,7 +34,8 @@ impl OpenAIClient {
let api_key = self.get_api_key()?;
let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());
let body = openai_build_body(data, &self.model);
let mut body = openai_build_body(data, &self.model);
self.patch_request_body(&mut body);
let url = format!("{api_base}/chat/completions");

@ -1,8 +1,6 @@
use crate::client::OPENAI_COMPATIBLE_PLATFORMS;
use super::openai::openai_build_body;
use super::{
ExtraConfig, Model, ModelData, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
openai::*, Client, ExtraConfig, Model, ModelData, ModelPatches, OpenAICompatibleClient,
PromptAction, PromptKind, SendData, OPENAI_COMPATIBLE_PLATFORMS,
};
use anyhow::Result;
@ -17,6 +15,7 @@ pub struct OpenAICompatibleConfig {
pub chat_endpoint: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -58,7 +57,7 @@ impl OpenAICompatibleClient {
let api_key = self.get_api_key().ok();
let mut body = openai_build_body(data, &self.model);
self.model.merge_extra_fields(&mut body);
self.patch_request_body(&mut body);
let chat_endpoint = self
.config

@ -1,6 +1,7 @@
use super::{
maybe_catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model,
ModelData, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage,
SseHandler,
};
use crate::utils::{base64_decode, sha256};
@ -27,6 +28,7 @@ pub struct QianwenConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -46,7 +48,8 @@ impl QianwenClient {
true => API_URL_VL,
false => API_URL,
};
let (body, has_upload) = build_body(data, &self.model, is_vl)?;
let (mut body, has_upload) = build_body(data, &self.model, is_vl)?;
self.patch_request_body(&mut body);
debug!("Qianwen Request: {url} {body}");

@ -1,9 +1,7 @@
use std::time::Duration;
use super::{
catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, ReplicateClient, SendData, SsMmessage,
SseHandler,
catch_error, prompt_format::*, sse_stream, Client, CompletionOutput, ExtraConfig,
Model, ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient, SendData,
SsMmessage, SseHandler,
};
use anyhow::{anyhow, Result};
@ -11,6 +9,7 @@ use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::time::Duration;
const API_BASE: &str = "https://api.replicate.com/v1";
@ -20,6 +19,7 @@ pub struct ReplicateConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -35,7 +35,8 @@ impl ReplicateClient {
data: SendData,
api_key: &str,
) -> Result<RequestBuilder> {
let body = build_body(data, &self.model)?;
let mut body = build_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = format!("{API_BASE}/models/{}/predictions", self.model.name());

@ -1,7 +1,7 @@
use super::{
access_token::*, catch_error, json_stream, message::*, patch_system_message, Client,
CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
SseHandler, ToolCall, VertexAIClient,
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
SendData, SseHandler, ToolCall, VertexAIClient,
};
use anyhow::{anyhow, bail, Context, Result};
@ -22,6 +22,7 @@ pub struct VertexAIConfig {
pub safety_settings: Option<Value>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -47,7 +48,8 @@ impl VertexAIClient {
};
let url = format!("{base_url}/google/models/{}:{func}", self.model.name());
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
let mut body = gemini_build_body(data, &self.model)?;
self.patch_request_body(&mut body);
debug!("VertexAI Request: {url} {body}");
@ -178,7 +180,6 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
pub(crate) fn gemini_build_body(
data: SendData,
model: &Model,
safety_settings: Option<Value>,
) -> Result<Value> {
let SendData {
mut messages,
@ -259,10 +260,6 @@ pub(crate) fn gemini_build_body(
let mut body = json!({ "contents": contents, "generationConfig": {} });
if let Some(safety_settings) = safety_settings {
body["safetySettings"] = safety_settings;
}
if let Some(v) = model.max_tokens_param() {
body["generationConfig"]["maxOutputTokens"] = v.into();
}

@ -1,9 +1,6 @@
use super::access_token::*;
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::vertexai::prepare_gcloud_access_token;
use super::{
Client, CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
SseHandler, VertexAIClaudeClient,
access_token::*, claude::*, vertexai::*, Client, CompletionOutput, ExtraConfig, Model,
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, VertexAIClaudeClient,
};
use anyhow::Result;
@ -19,6 +16,7 @@ pub struct VertexAIClaudeConfig {
pub adc_file: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
pub extra: Option<ExtraConfig>,
}
@ -43,6 +41,7 @@ impl VertexAIClaudeClient {
);
let mut body = claude_build_body(data, &self.model)?;
self.patch_request_body(&mut body);
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
}