mirror of
https://github.com/sigoden/aichat
synced 2024-11-18 09:28:27 +00:00
feat: allow patching req body with client config (#534)
This commit is contained in:
parent
91a06543b2
commit
ba3bcfd67c
33
Cargo.lock
generated
33
Cargo.lock
generated
@ -54,6 +54,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"inquire",
|
||||
"is-terminal",
|
||||
"json-patch",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"mime_guess",
|
||||
@ -715,6 +716,15 @@ dependencies = [
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fluent-uri"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@ -1139,6 +1149,29 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json-patch"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b1fb8864823fad91877e6caea0baca82e49e8db50f8e5c9f9a453e27d3330fc"
|
||||
dependencies = [
|
||||
"jsonptr",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonptr"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c6e529149475ca0b2820835d3dce8fcc41c6b943ca608d32f35b449255e4627"
|
||||
dependencies = [
|
||||
"fluent-uri",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
|
@ -58,6 +58,7 @@ urlencoding = "2.1.3"
|
||||
unicode-segmentation = "1.11.0"
|
||||
num_cpus = "1.16.0"
|
||||
threadpool = "1.8.1"
|
||||
json-patch = { version = "2.0.0", default-features = false }
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12.0"
|
||||
|
@ -37,8 +37,9 @@ clients:
|
||||
# - name: xxxx # The model name
|
||||
# max_input_tokens: 100000
|
||||
# supports_vision: true
|
||||
# extra_fields: # Set custom parameters, will merge with the body json
|
||||
# key: value
|
||||
# patches:
|
||||
# <regex>: # The regex to match model names, e.g. '.*' 'gpt-4o' 'gpt-4o|gpt-4-.*'
|
||||
# request_body: # The JSON to be merged with the request body.
|
||||
# extra:
|
||||
# proxy: socks5://127.0.0.1:1080 # Set https/socks5 proxy. ENV: HTTPS_PROXY/https_proxy/ALL_PROXY/all_proxy
|
||||
# connect_timeout: 10 # Set timeout in seconds for connecting to the API
|
||||
@ -62,15 +63,18 @@ clients:
|
||||
# See https://ai.google.dev/docs
|
||||
- type: gemini
|
||||
api_key: xxx # ENV: {client}_API_KEY
|
||||
safetySettings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_NONE
|
||||
patches:
|
||||
'.*':
|
||||
request_body: # Override safetySettings for all models
|
||||
safetySettings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_NONE
|
||||
|
||||
# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
|
||||
- type: claude
|
||||
@ -123,15 +127,18 @@ clients:
|
||||
# Run `gcloud auth application-default login` to init the adc file
|
||||
# see https://cloud.google.com/docs/authentication/external/set-up-adc
|
||||
adc_file: <path-to/gcloud/application_default_credentials.json>
|
||||
safetySettings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
patches:
|
||||
'gemini-.*':
|
||||
request_body: # Override safetySettings for all gemini models
|
||||
safetySettings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_ONLY_HIGH
|
||||
|
||||
# See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
|
||||
- type: vertexai-claude
|
||||
|
@ -1,5 +1,7 @@
|
||||
use super::openai::openai_build_body;
|
||||
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData};
|
||||
use super::{
|
||||
openai::*, AzureOpenAIClient, Client, ExtraConfig, Model, ModelData, ModelPatches,
|
||||
PromptAction, PromptKind, SendData,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
@ -11,6 +13,7 @@ pub struct AzureOpenAIConfig {
|
||||
pub api_base: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -35,7 +38,7 @@ impl AzureOpenAIClient {
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut body = openai_build_body(data, &self.model);
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
|
||||
|
@ -1,8 +1,7 @@
|
||||
use super::claude::{claude_build_body, claude_extract_completion};
|
||||
use super::{
|
||||
catch_error, generate_prompt, BedrockClient, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, PromptAction, PromptFormat, PromptKind, SendData, SseHandler, LLAMA3_PROMPT_FORMAT,
|
||||
MISTRAL_PROMPT_FORMAT,
|
||||
prompt_format::*, claude::*,
|
||||
catch_error, BedrockClient, Client, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
|
||||
};
|
||||
|
||||
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
||||
@ -31,6 +30,7 @@ pub struct BedrockConfig {
|
||||
pub region: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ impl BedrockClient {
|
||||
let headers = IndexMap::new();
|
||||
|
||||
let mut body = build_body(data, &self.model, model_category)?;
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let builder = aws_fetch(
|
||||
client,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, CompletionOutput,
|
||||
ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction,
|
||||
PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
|
||||
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client,
|
||||
CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData,
|
||||
ModelPatches, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
@ -17,6 +17,7 @@ pub struct ClaudeConfig {
|
||||
pub api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -29,7 +30,8 @@ impl ClaudeClient {
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let body = claude_build_body(data, &self.model)?;
|
||||
let mut body = claude_build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = API_BASE;
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
catch_error, sse_stream, CloudflareClient, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
|
||||
catch_error, sse_stream, Client, CloudflareClient, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
@ -17,6 +17,7 @@ pub struct CloudflareConfig {
|
||||
pub api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -33,7 +34,8 @@ impl CloudflareClient {
|
||||
let account_id = self.get_account_id()?;
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let body = build_body(data, &self.model)?;
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = format!(
|
||||
"{API_BASE}/accounts/{account_id}/ai/run/{}",
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::{
|
||||
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionOutput,
|
||||
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SseHandler, ToolCall,
|
||||
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
||||
SendData, SseHandler, ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
@ -16,6 +17,7 @@ pub struct CohereConfig {
|
||||
pub api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -28,7 +30,8 @@ impl CohereClient {
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let body = build_body(data, &self.model)?;
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = API_URL;
|
||||
|
||||
|
@ -9,7 +9,9 @@ use crate::{
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use fancy_regex::Regex;
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use indexmap::IndexMap;
|
||||
use lazy_static::lazy_static;
|
||||
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
|
||||
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
||||
@ -23,6 +25,7 @@ const MODELS_YAML: &str = include_str!("../../models.yaml");
|
||||
// Lazily-initialised globals shared by the client code.
lazy_static! {
|
||||
// Built-in model definitions, deserialized from the `MODELS_YAML` text.
// The `unwrap()` panics if that text fails to deserialize.
pub static ref ALL_CLIENT_MODELS: Vec<BuiltinModels> =
|
||||
serde_yaml::from_str(MODELS_YAML).unwrap();
|
||||
// Matches a `/` that is not immediately preceded by a backslash
// (negative look-behind); used to escape bare slashes in patch keys.
static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(?<!\\)/").unwrap();
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
@ -158,13 +161,16 @@ macro_rules! register_client {
|
||||
#[macro_export]
|
||||
macro_rules! client_common_fns {
|
||||
() => {
|
||||
fn config(
|
||||
&self,
|
||||
) -> (
|
||||
&$crate::config::GlobalConfig,
|
||||
&Option<$crate::client::ExtraConfig>,
|
||||
) {
|
||||
(&self.global_config, &self.config.extra)
|
||||
fn global_config(&self) -> &$crate::config::GlobalConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
|
||||
self.config.extra.as_ref()
|
||||
}
|
||||
|
||||
fn patches_config(&self) -> Option<&$crate::client::ModelPatches> {
|
||||
self.config.patches.as_ref()
|
||||
}
|
||||
|
||||
fn list_models(&self) -> Vec<Model> {
|
||||
@ -246,8 +252,13 @@ macro_rules! unsupported_model {
|
||||
|
||||
#[async_trait]
|
||||
pub trait Client: Sync + Send {
|
||||
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
|
||||
fn global_config(&self) -> &GlobalConfig;
|
||||
|
||||
fn extra_config(&self) -> Option<&ExtraConfig>;
|
||||
|
||||
fn patches_config(&self) -> Option<&ModelPatches>;
|
||||
|
||||
#[allow(unused)]
|
||||
fn name(&self) -> &str;
|
||||
|
||||
#[allow(unused)]
|
||||
@ -262,12 +273,9 @@ pub trait Client: Sync + Send {
|
||||
|
||||
fn build_client(&self) -> Result<ReqwestClient> {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
let options = self.config().1;
|
||||
let timeout = options
|
||||
.as_ref()
|
||||
.and_then(|v| v.connect_timeout)
|
||||
.unwrap_or(10);
|
||||
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
||||
let extra = self.extra_config();
|
||||
let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
|
||||
let proxy = extra.and_then(|v| v.proxy.clone());
|
||||
builder = set_proxy(builder, &proxy)?;
|
||||
let client = builder
|
||||
.connect_timeout(Duration::from_secs(timeout))
|
||||
@ -277,8 +285,7 @@ pub trait Client: Sync + Send {
|
||||
}
|
||||
|
||||
async fn send_message(&self, input: Input) -> Result<CompletionOutput> {
|
||||
let global_config = self.config().0;
|
||||
if global_config.read().dry_run {
|
||||
if self.global_config().read().dry_run {
|
||||
let content = input.echo_messages();
|
||||
return Ok(CompletionOutput::new(&content));
|
||||
}
|
||||
@ -303,8 +310,7 @@ pub trait Client: Sync + Send {
|
||||
let input = input.clone();
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
let global_config = self.config().0;
|
||||
if global_config.read().dry_run {
|
||||
if self.global_config().read().dry_run {
|
||||
let content = input.echo_messages();
|
||||
let tokens = tokenize(&content);
|
||||
for token in tokens {
|
||||
@ -327,6 +333,15 @@ pub trait Client: Sync + Send {
|
||||
}
|
||||
}
|
||||
|
||||
fn patch_request_body(&self, body: &mut Value) {
|
||||
let model_name = self.model().name();
|
||||
if let Some(patch_data) = slect_model_patch(self.patches_config(), model_name) {
|
||||
if body.is_object() && patch_data.request_body.is_object() {
|
||||
json_patch::merge(body, &patch_data.request_body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
@ -353,6 +368,30 @@ pub struct ExtraConfig {
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
/// Patch table of a client config: maps a key string to the patch it selects.
/// `slect_model_patch` treats each key as a regex matched against the model name.
pub type ModelPatches = IndexMap<String, ModelPatch>;
|
||||
|
||||
/// A single patch entry from the client config.
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelPatch {
|
||||
/// JSON merged into the outgoing request body by `patch_request_body`
/// (only when it is a JSON object). Falls back to the `Value` default
/// when the field is absent from the config.
#[serde(default)]
|
||||
pub request_body: Value,
|
||||
}
|
||||
|
||||
pub fn slect_model_patch<'a>(
|
||||
patch: Option<&'a ModelPatches>,
|
||||
name: &str,
|
||||
) -> Option<&'a ModelPatch> {
|
||||
let patch = patch?;
|
||||
for (key, patch_data) in patch {
|
||||
let key = ESCAPE_SLASH_RE.replace_all(key, r"\/");
|
||||
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
|
||||
if let Ok(true) = regex.is_match(name) {
|
||||
return Some(patch_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SendData {
|
||||
pub messages: Vec<Message>,
|
||||
|
@ -1,7 +1,8 @@
|
||||
use super::access_token::*;
|
||||
use super::{
|
||||
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient,
|
||||
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
|
||||
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SendData, SsMmessage,
|
||||
SseHandler,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
@ -21,6 +22,7 @@ pub struct ErnieConfig {
|
||||
pub secret_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -31,7 +33,9 @@ impl ErnieClient {
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let body = build_body(data, &self.model);
|
||||
let mut body = build_body(data, &self.model);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let access_token = get_access_token(self.name())?;
|
||||
|
||||
let url = format!(
|
||||
|
@ -1,5 +1,7 @@
|
||||
use super::vertexai::gemini_build_body;
|
||||
use super::{ExtraConfig, GeminiClient, Model, ModelData, PromptAction, PromptKind, SendData};
|
||||
use super::{
|
||||
vertexai::*, Client, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, PromptAction,
|
||||
PromptKind, SendData,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
@ -15,6 +17,7 @@ pub struct GeminiConfig {
|
||||
pub safety_settings: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -32,7 +35,8 @@ impl GeminiClient {
|
||||
false => "generateContent",
|
||||
};
|
||||
|
||||
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
|
||||
let mut body = gemini_build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let model = &self.model.name();
|
||||
|
||||
|
@ -11,7 +11,6 @@ pub use crate::utils::PromptKind;
|
||||
pub use common::*;
|
||||
pub use message::*;
|
||||
pub use model::*;
|
||||
pub use prompt_format::*;
|
||||
pub use sse_handler::*;
|
||||
|
||||
register_client!(
|
||||
|
@ -191,26 +191,6 @@ impl Model {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Merges the model's `extra_fields` into the request `body`.
///
/// Does nothing unless `body` is a JSON object and the model has
/// `extra_fields`. For each extra field:
/// - if `body` has no such key, the field is inserted as-is;
/// - if the key exists and both values are objects, only the sub-keys
///   missing from the existing object are added (existing values win);
/// - otherwise the existing value is kept unchanged.
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
    let (target, extras) = match (body.as_object_mut(), &self.data.extra_fields) {
        (Some(target), Some(extras)) => (target, extras),
        _ => return,
    };
    for (key, extra_field) in extras {
        match target.get_mut(key) {
            None => {
                target.insert(key.clone(), extra_field.clone());
            }
            Some(existing) => {
                // One-level-deep merge: only fill in sub-keys that are absent.
                if let (Some(sub_body), Some(extra_obj)) =
                    (existing.as_object_mut(), extra_field.as_object())
                {
                    for (subkey, sub_field) in extra_obj {
                        if !sub_body.contains_key(subkey) {
                            sub_body.insert(subkey.clone(), sub_field.clone());
                        }
                    }
                }
            }
        }
    }
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
@ -226,7 +206,6 @@ pub struct ModelData {
|
||||
pub supports_vision: bool,
|
||||
#[serde(default)]
|
||||
pub supports_function_calling: bool,
|
||||
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
impl ModelData {
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
catch_error, message::*, CompletionOutput, ExtraConfig, Model, ModelData, OllamaClient,
|
||||
PromptAction, PromptKind, SendData, SseHandler,
|
||||
catch_error, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches,
|
||||
OllamaClient, PromptAction, PromptKind, SendData, SseHandler,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
@ -16,6 +16,7 @@ pub struct OllamaConfig {
|
||||
pub api_auth: Option<String>,
|
||||
pub chat_endpoint: Option<String>,
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -40,7 +41,7 @@ impl OllamaClient {
|
||||
let api_auth = self.get_api_auth().ok();
|
||||
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::{
|
||||
catch_error, message::*, sse_stream, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
|
||||
catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
|
||||
ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
@ -18,6 +19,7 @@ pub struct OpenAIConfig {
|
||||
pub organization_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -32,7 +34,8 @@ impl OpenAIClient {
|
||||
let api_key = self.get_api_key()?;
|
||||
let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let body = openai_build_body(data, &self.model);
|
||||
let mut body = openai_build_body(data, &self.model);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = format!("{api_base}/chat/completions");
|
||||
|
||||
|
@ -1,8 +1,6 @@
|
||||
use crate::client::OPENAI_COMPATIBLE_PLATFORMS;
|
||||
|
||||
use super::openai::openai_build_body;
|
||||
use super::{
|
||||
ExtraConfig, Model, ModelData, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
|
||||
openai::*, Client, ExtraConfig, Model, ModelData, ModelPatches, OpenAICompatibleClient,
|
||||
PromptAction, PromptKind, SendData, OPENAI_COMPATIBLE_PLATFORMS,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -17,6 +15,7 @@ pub struct OpenAICompatibleConfig {
|
||||
pub chat_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -58,7 +57,7 @@ impl OpenAICompatibleClient {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let mut body = openai_build_body(data, &self.model);
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let chat_endpoint = self
|
||||
.config
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::{
|
||||
maybe_catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage,
|
||||
SseHandler,
|
||||
};
|
||||
|
||||
use crate::utils::{base64_decode, sha256};
|
||||
@ -27,6 +28,7 @@ pub struct QianwenConfig {
|
||||
pub api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -46,7 +48,8 @@ impl QianwenClient {
|
||||
true => API_URL_VL,
|
||||
false => API_URL,
|
||||
};
|
||||
let (body, has_upload) = build_body(data, &self.model, is_vl)?;
|
||||
let (mut body, has_upload) = build_body(data, &self.model, is_vl)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
debug!("Qianwen Request: {url} {body}");
|
||||
|
||||
|
@ -1,9 +1,7 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use super::{
|
||||
catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionOutput,
|
||||
ExtraConfig, Model, ModelData, PromptAction, PromptKind, ReplicateClient, SendData, SsMmessage,
|
||||
SseHandler,
|
||||
catch_error, prompt_format::*, sse_stream, Client, CompletionOutput, ExtraConfig,
|
||||
Model, ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient, SendData,
|
||||
SsMmessage, SseHandler,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
@ -11,6 +9,7 @@ use async_trait::async_trait;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
|
||||
const API_BASE: &str = "https://api.replicate.com/v1";
|
||||
|
||||
@ -20,6 +19,7 @@ pub struct ReplicateConfig {
|
||||
pub api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -35,7 +35,8 @@ impl ReplicateClient {
|
||||
data: SendData,
|
||||
api_key: &str,
|
||||
) -> Result<RequestBuilder> {
|
||||
let body = build_body(data, &self.model)?;
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
let url = format!("{API_BASE}/models/{}/predictions", self.model.name());
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
access_token::*, catch_error, json_stream, message::*, patch_system_message, Client,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
|
||||
SseHandler, ToolCall, VertexAIClient,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
||||
SendData, SseHandler, ToolCall, VertexAIClient,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
@ -22,6 +22,7 @@ pub struct VertexAIConfig {
|
||||
pub safety_settings: Option<Value>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -47,7 +48,8 @@ impl VertexAIClient {
|
||||
};
|
||||
let url = format!("{base_url}/google/models/{}:{func}", self.model.name());
|
||||
|
||||
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
|
||||
let mut body = gemini_build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
debug!("VertexAI Request: {url} {body}");
|
||||
|
||||
@ -178,7 +180,6 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
|
||||
pub(crate) fn gemini_build_body(
|
||||
data: SendData,
|
||||
model: &Model,
|
||||
safety_settings: Option<Value>,
|
||||
) -> Result<Value> {
|
||||
let SendData {
|
||||
mut messages,
|
||||
@ -259,10 +260,6 @@ pub(crate) fn gemini_build_body(
|
||||
|
||||
let mut body = json!({ "contents": contents, "generationConfig": {} });
|
||||
|
||||
if let Some(safety_settings) = safety_settings {
|
||||
body["safetySettings"] = safety_settings;
|
||||
}
|
||||
|
||||
if let Some(v) = model.max_tokens_param() {
|
||||
body["generationConfig"]["maxOutputTokens"] = v.into();
|
||||
}
|
||||
|
@ -1,9 +1,6 @@
|
||||
use super::access_token::*;
|
||||
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
|
||||
use super::vertexai::prepare_gcloud_access_token;
|
||||
use super::{
|
||||
Client, CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
|
||||
SseHandler, VertexAIClaudeClient,
|
||||
access_token::*, claude::*, vertexai::*, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, VertexAIClaudeClient,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -19,6 +16,7 @@ pub struct VertexAIClaudeConfig {
|
||||
pub adc_file: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patches: Option<ModelPatches>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
@ -43,6 +41,7 @@ impl VertexAIClaudeClient {
|
||||
);
|
||||
|
||||
let mut body = claude_build_body(data, &self.model)?;
|
||||
self.patch_request_body(&mut body);
|
||||
if let Some(body_obj) = body.as_object_mut() {
|
||||
body_obj.remove("model");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user