feat: vertexai support mistral models (#746)

Branch: pull/749/head
Author: sigoden · committed 3 months ago via GitHub
Parent: 96ad64352d · Commit: a6f0196017

@@ -277,7 +277,7 @@
 - platform: vertexai
   # docs:
   # - https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
-  # - https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
+  # - https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/explore-models
   # - https://cloud.google.com/vertex-ai/generative-ai/pricing
   # - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
   # notes:
@@ -335,6 +335,20 @@
       output_price: 1.25
       supports_vision: true
      supports_function_calling: true
+    - name: mistral-large@2407
+      max_input_tokens: 128000
+      input_price: 3
+      output_price: 9
+      supports_function_calling: true
+    - name: mistral-nemo@2407
+      max_input_tokens: 128000
+      input_price: 0.3
+      output_price: 0.3
+      supports_function_calling: true
+    - name: codestral@2405
+      max_input_tokens: 32000
+      input_price: 1
+      output_price: 3
     - name: text-embedding-004
       type: embedding
       max_input_tokens: 3072
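For orientation, here is a minimal sketch, not the project's actual types, of how one of the entries added above could be deserialized. The field names simply mirror the YAML keys, and the prices appear to be USD per million tokens, in line with Mistral's published Vertex AI pricing; both points are assumptions, not something stated in the patch.

```rust
use serde::Deserialize;

// Hypothetical mirror of a single models.yaml entry; field names follow the YAML keys above.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    name: String,
    max_input_tokens: Option<u32>,
    input_price: Option<f64>,  // assumed: USD per 1M input tokens
    output_price: Option<f64>, // assumed: USD per 1M output tokens
    #[serde(default)]
    supports_function_calling: bool,
}

fn main() -> Result<(), serde_yaml::Error> {
    let yaml = "\
name: mistral-large@2407
max_input_tokens: 128000
input_price: 3
output_price: 9
supports_function_calling: true
";
    let entry: ModelEntry = serde_yaml::from_str(yaml)?;
    println!("{entry:?}");
    Ok(())
}
```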

@@ -1,5 +1,6 @@
 use super::access_token::*;
 use super::claude::*;
+use super::openai::*;
 use super::*;

 use anyhow::{anyhow, bail, Context, Result};
@@ -56,6 +57,13 @@ impl VertexAIClient {
            ModelCategory::Claude => {
                format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
            }
+            ModelCategory::Mistral => {
+                let func = match data.stream {
+                    true => "streamRawPredict",
+                    false => "rawPredict",
+                };
+                format!("{base_url}/mistralai/models/{model_name}:{func}")
+            }
        };

        let mut body = match model_category {
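To make the new routing concrete, here is a standalone sketch that mirrors the Mistral arm above: streaming requests use `streamRawPredict`, non-streaming requests use `rawPredict`, and the model name keeps its `@version` suffix in the URL path. The base URL in the example is only an assumption about the usual Vertex AI publishers path; the real value is assembled elsewhere in the client.

```rust
// Illustrative only: re-states the Mistral endpoint construction from the hunk above.
fn mistral_endpoint(base_url: &str, model_name: &str, stream: bool) -> String {
    let func = if stream { "streamRawPredict" } else { "rawPredict" };
    format!("{base_url}/mistralai/models/{model_name}:{func}")
}

fn main() {
    // Hypothetical base_url, shown only to illustrate the resulting shape.
    let base_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers";
    println!("{}", mistral_endpoint(base_url, "mistral-large@2407", true));
}
```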
@@ -68,6 +76,13 @@ impl VertexAIClient {
                body["anthropic_version"] = "vertex-2023-10-16".into();
                body
            }
+            ModelCategory::Mistral => {
+                let mut body = openai_build_chat_completions_body(data, &self.model);
+                if let Some(body_obj) = body.as_object_mut() {
+                    body_obj["model"] = strip_model_version(self.model.name()).into();
+                }
+                body
+            }
        };

        self.patch_chat_completions_body(&mut body);
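The Mistral arm reuses the OpenAI-compatible request body and then rewrites its `model` field. A rough sketch of that rewrite with plain `serde_json` values follows; the message payload is made up for illustration and only stands in for whatever `openai_build_chat_completions_body` produces.

```rust
use serde_json::json;

fn main() {
    // Stand-in for an OpenAI-style chat completions body (illustrative only).
    let mut body = json!({
        "model": "mistral-large@2407",
        "messages": [{ "role": "user", "content": "Hello" }],
        "stream": true,
    });
    // As in the patch above: the body carries the model name without the `@version`
    // suffix, while the version stays encoded in the request URL.
    if let Some(body_obj) = body.as_object_mut() {
        body_obj.insert("model".into(), "mistral-large".into());
    }
    assert_eq!(body["model"], "mistral-large");
    println!("{body}");
}
```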
@@ -122,6 +137,7 @@ impl Client for VertexAIClient {
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder).await,
            ModelCategory::Claude => claude_chat_completions(builder).await,
+            ModelCategory::Mistral => openai_chat_completions(builder).await,
        }
    }
@@ -137,6 +153,7 @@ impl Client for VertexAIClient {
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions_streaming(builder, handler).await,
            ModelCategory::Claude => claude_chat_completions_streaming(builder, handler).await,
+            ModelCategory::Mistral => openai_chat_completions_streaming(builder, handler).await,
        }
    }
@@ -379,7 +396,9 @@ pub fn gemini_build_chat_completions_body(
    if let Some(functions) = functions {
        // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
-        let function_declarations: Vec<_> = functions.into_iter().map(|function| {
+        let function_declarations: Vec<_> = functions
+            .into_iter()
+            .map(|function| {
            if function.parameters.is_empty_properties() {
                json!({
                    "name": function.name,
@@ -388,7 +407,8 @@ pub fn gemini_build_chat_completions_body(
            } else {
                json!(function)
            }
-        }).collect();
+            })
+            .collect();
        body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
    }
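This hunk is only a rustfmt-style reflow, but the comment it touches describes a real workaround: Gemini does not accept function declarations whose `parameters` object has an empty `properties` map, so the empty-properties branch builds a trimmed declaration (in the lines shown, at least the function name is kept). A self-contained sketch of that idea; the struct and the exact fields kept are assumptions for illustration, not the project's types.

```rust
use serde_json::{json, Value};

// Hypothetical stand-in for a tool/function definition.
struct Function {
    name: String,
    parameters: Value,
}

fn declaration(function: Function) -> Value {
    // Treat a missing or empty `properties` map as "no usable parameters".
    let empty_properties = function.parameters["properties"]
        .as_object()
        .map_or(true, |m| m.is_empty());
    if empty_properties {
        // Drop `parameters` entirely so Gemini accepts the declaration.
        json!({ "name": function.name })
    } else {
        json!({ "name": function.name, "parameters": function.parameters })
    }
}

fn main() {
    let f = Function {
        name: "get_current_time".into(),
        parameters: json!({ "type": "object", "properties": {} }),
    };
    println!("{}", declaration(f));
}
```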
@@ -399,16 +419,19 @@
 enum ModelCategory {
     Gemini,
     Claude,
+    Mistral,
 }

 impl FromStr for ModelCategory {
     type Err = anyhow::Error;

     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        if s.starts_with("gemini-") {
+        if s.starts_with("gemini") {
             Ok(ModelCategory::Gemini)
-        } else if s.starts_with("claude-") {
+        } else if s.starts_with("claude") {
             Ok(ModelCategory::Claude)
+        } else if s.starts_with("mistral") || s.starts_with("codestral") {
+            Ok(ModelCategory::Mistral)
         } else {
             unsupported_model!(s)
         }
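A quick, hypothetical check of the relaxed prefix matching (not a test from this PR): the `gemini-`/`claude-` prefixes lose their trailing dash, and two new prefixes route Mistral-family models, with `codestral` matched explicitly because it does not start with `mistral`.

```rust
// Self-contained re-statement of the dispatch logic above, for illustration only.
fn categorize(s: &str) -> &'static str {
    if s.starts_with("gemini") {
        "Gemini"
    } else if s.starts_with("claude") {
        "Claude"
    } else if s.starts_with("mistral") || s.starts_with("codestral") {
        "Mistral"
    } else {
        "unsupported"
    }
}

fn main() {
    assert_eq!(categorize("mistral-large@2407"), "Mistral");
    assert_eq!(categorize("mistral-nemo@2407"), "Mistral");
    assert_eq!(categorize("codestral@2405"), "Mistral");
    assert_eq!(categorize("gemini-1.5-flash-001"), "Gemini");
    println!("ok");
}
```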
@@ -496,3 +519,10 @@ fn default_adc_file() -> Option<PathBuf> {
     path.push("application_default_credentials.json");
     Some(path)
 }
+
+fn strip_model_version(name: &str) -> &str {
+    match name.split_once('@') {
+        Some((v, _)) => v,
+        None => name,
+    }
+}
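The new helper simply trims everything from the first `@` onward, so the versioned id that appears in the URL becomes a bare model name in the request body. A tiny standalone check (not part of the patch):

```rust
fn strip_model_version(name: &str) -> &str {
    match name.split_once('@') {
        Some((v, _)) => v,
        None => name,
    }
}

fn main() {
    assert_eq!(strip_model_version("mistral-large@2407"), "mistral-large");
    assert_eq!(strip_model_version("codestral@2405"), "codestral");
    assert_eq!(strip_model_version("mistral-nemo"), "mistral-nemo");
    println!("ok");
}
```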
