|
|
@ -1,5 +1,6 @@
|
|
|
|
use super::access_token::*;
|
|
|
|
use super::access_token::*;
|
|
|
|
use super::claude::*;
|
|
|
|
use super::claude::*;
|
|
|
|
|
|
|
|
use super::openai::*;
|
|
|
|
use super::*;
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
@ -56,6 +57,13 @@ impl VertexAIClient {
|
|
|
|
ModelCategory::Claude => {
|
|
|
|
ModelCategory::Claude => {
|
|
|
|
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
|
|
|
|
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ModelCategory::Mistral => {
|
|
|
|
|
|
|
|
let func = match data.stream {
|
|
|
|
|
|
|
|
true => "streamRawPredict",
|
|
|
|
|
|
|
|
false => "rawPredict",
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
format!("{base_url}/mistralai/models/{model_name}:{func}")
|
|
|
|
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let mut body = match model_category {
|
|
|
|
let mut body = match model_category {
|
|
|
@ -68,6 +76,13 @@ impl VertexAIClient {
|
|
|
|
body["anthropic_version"] = "vertex-2023-10-16".into();
|
|
|
|
body["anthropic_version"] = "vertex-2023-10-16".into();
|
|
|
|
body
|
|
|
|
body
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ModelCategory::Mistral => {
|
|
|
|
|
|
|
|
let mut body = openai_build_chat_completions_body(data, &self.model);
|
|
|
|
|
|
|
|
if let Some(body_obj) = body.as_object_mut() {
|
|
|
|
|
|
|
|
body_obj["model"] = strip_model_version(self.model.name()).into();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
body
|
|
|
|
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
self.patch_chat_completions_body(&mut body);
|
|
|
|
self.patch_chat_completions_body(&mut body);
|
|
|
|
|
|
|
|
|
|
|
@ -122,6 +137,7 @@ impl Client for VertexAIClient {
|
|
|
|
match model_category {
|
|
|
|
match model_category {
|
|
|
|
ModelCategory::Gemini => gemini_chat_completions(builder).await,
|
|
|
|
ModelCategory::Gemini => gemini_chat_completions(builder).await,
|
|
|
|
ModelCategory::Claude => claude_chat_completions(builder).await,
|
|
|
|
ModelCategory::Claude => claude_chat_completions(builder).await,
|
|
|
|
|
|
|
|
ModelCategory::Mistral => openai_chat_completions(builder).await,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -137,6 +153,7 @@ impl Client for VertexAIClient {
|
|
|
|
match model_category {
|
|
|
|
match model_category {
|
|
|
|
ModelCategory::Gemini => gemini_chat_completions_streaming(builder, handler).await,
|
|
|
|
ModelCategory::Gemini => gemini_chat_completions_streaming(builder, handler).await,
|
|
|
|
ModelCategory::Claude => claude_chat_completions_streaming(builder, handler).await,
|
|
|
|
ModelCategory::Claude => claude_chat_completions_streaming(builder, handler).await,
|
|
|
|
|
|
|
|
ModelCategory::Mistral => openai_chat_completions_streaming(builder, handler).await,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -379,7 +396,9 @@ pub fn gemini_build_chat_completions_body(
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(functions) = functions {
|
|
|
|
if let Some(functions) = functions {
|
|
|
|
// Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
|
|
|
|
// Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
|
|
|
|
let function_declarations: Vec<_> = functions.into_iter().map(|function| {
|
|
|
|
let function_declarations: Vec<_> = functions
|
|
|
|
|
|
|
|
.into_iter()
|
|
|
|
|
|
|
|
.map(|function| {
|
|
|
|
if function.parameters.is_empty_properties() {
|
|
|
|
if function.parameters.is_empty_properties() {
|
|
|
|
json!({
|
|
|
|
json!({
|
|
|
|
"name": function.name,
|
|
|
|
"name": function.name,
|
|
|
@ -388,7 +407,8 @@ pub fn gemini_build_chat_completions_body(
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
json!(function)
|
|
|
|
json!(function)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}).collect();
|
|
|
|
})
|
|
|
|
|
|
|
|
.collect();
|
|
|
|
body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
|
|
|
|
body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -399,16 +419,19 @@ pub fn gemini_build_chat_completions_body(
|
|
|
|
enum ModelCategory {
|
|
|
|
enum ModelCategory {
|
|
|
|
Gemini,
|
|
|
|
Gemini,
|
|
|
|
Claude,
|
|
|
|
Claude,
|
|
|
|
|
|
|
|
Mistral,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl FromStr for ModelCategory {
|
|
|
|
impl FromStr for ModelCategory {
|
|
|
|
type Err = anyhow::Error;
|
|
|
|
type Err = anyhow::Error;
|
|
|
|
|
|
|
|
|
|
|
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
|
|
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
|
|
|
if s.starts_with("gemini-") {
|
|
|
|
if s.starts_with("gemini") {
|
|
|
|
Ok(ModelCategory::Gemini)
|
|
|
|
Ok(ModelCategory::Gemini)
|
|
|
|
} else if s.starts_with("claude-") {
|
|
|
|
} else if s.starts_with("claude") {
|
|
|
|
Ok(ModelCategory::Claude)
|
|
|
|
Ok(ModelCategory::Claude)
|
|
|
|
|
|
|
|
} else if s.starts_with("mistral") || s.starts_with("codestral") {
|
|
|
|
|
|
|
|
Ok(ModelCategory::Mistral)
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
unsupported_model!(s)
|
|
|
|
unsupported_model!(s)
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -496,3 +519,10 @@ fn default_adc_file() -> Option<PathBuf> {
|
|
|
|
path.push("application_default_credentials.json");
|
|
|
|
path.push("application_default_credentials.json");
|
|
|
|
Some(path)
|
|
|
|
Some(path)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn strip_model_version(name: &str) -> &str {
|
|
|
|
|
|
|
|
match name.split_once('@') {
|
|
|
|
|
|
|
|
Some((v, _)) => v,
|
|
|
|
|
|
|
|
None => name,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|