From a6f01960177c51a36f3f43f619b5939e19b0c6c8 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 25 Jul 2024 16:24:28 -0700 Subject: [PATCH] feat: vertexai support mistral models (#746) --- models.yaml | 16 ++++++++++++- src/client/vertexai.rs | 54 ++++++++++++++++++++++++++++++++---------- 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/models.yaml b/models.yaml index cdced59..a2bdfae 100644 --- a/models.yaml +++ b/models.yaml @@ -277,7 +277,7 @@ - platform: vertexai # docs: # - https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models - # - https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude + # - https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/explore-models # - https://cloud.google.com/vertex-ai/generative-ai/pricing # - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini # notes: @@ -335,6 +335,20 @@ output_price: 1.25 supports_vision: true supports_function_calling: true + - name: mistral-large@2407 + max_input_tokens: 128000 + input_price: 3 + output_price: 9 + supports_function_calling: true + - name: mistral-nemo@2407 + max_input_tokens: 128000 + input_price: 0.3 + output_price: 0.3 + supports_function_calling: true + - name: codestral@2405 + max_input_tokens: 32000 + input_price: 1 + output_price: 3 - name: text-embedding-004 type: embedding max_input_tokens: 3072 diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index c9ae648..d93ccbd 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,5 +1,6 @@ use super::access_token::*; use super::claude::*; +use super::openai::*; use super::*; use anyhow::{anyhow, bail, Context, Result}; @@ -56,6 +57,13 @@ impl VertexAIClient { ModelCategory::Claude => { format!("{base_url}/anthropic/models/{model_name}:streamRawPredict") } + ModelCategory::Mistral => { + let func = match data.stream { + true => "streamRawPredict", + false => "rawPredict", + }; + format!("{base_url}/mistralai/models/{model_name}:{func}") + } }; let mut body = match model_category { @@ -68,6 +76,13 @@ impl VertexAIClient { body["anthropic_version"] = "vertex-2023-10-16".into(); body } + ModelCategory::Mistral => { + let mut body = openai_build_chat_completions_body(data, &self.model); + if let Some(body_obj) = body.as_object_mut() { + body_obj["model"] = strip_model_version(self.model.name()).into(); + } + body + } }; self.patch_chat_completions_body(&mut body); @@ -122,6 +137,7 @@ impl Client for VertexAIClient { match model_category { ModelCategory::Gemini => gemini_chat_completions(builder).await, ModelCategory::Claude => claude_chat_completions(builder).await, + ModelCategory::Mistral => openai_chat_completions(builder).await, } } @@ -137,6 +153,7 @@ impl Client for VertexAIClient { match model_category { ModelCategory::Gemini => gemini_chat_completions_streaming(builder, handler).await, ModelCategory::Claude => claude_chat_completions_streaming(builder, handler).await, + ModelCategory::Mistral => openai_chat_completions_streaming(builder, handler).await, } } @@ -379,16 +396,19 @@ pub fn gemini_build_chat_completions_body( if let Some(functions) = functions { // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it. - let function_declarations: Vec<_> = functions.into_iter().map(|function| { - if function.parameters.is_empty_properties() { - json!({ - "name": function.name, - "description": function.description, - }) - } else { - json!(function) - } - }).collect(); + let function_declarations: Vec<_> = functions + .into_iter() + .map(|function| { + if function.parameters.is_empty_properties() { + json!({ + "name": function.name, + "description": function.description, + }) + } else { + json!(function) + } + }) + .collect(); body["tools"] = json!([{ "functionDeclarations": function_declarations }]); } @@ -399,16 +419,19 @@ pub fn gemini_build_chat_completions_body( enum ModelCategory { Gemini, Claude, + Mistral, } impl FromStr for ModelCategory { type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { - if s.starts_with("gemini-") { + if s.starts_with("gemini") { Ok(ModelCategory::Gemini) - } else if s.starts_with("claude-") { + } else if s.starts_with("claude") { Ok(ModelCategory::Claude) + } else if s.starts_with("mistral") || s.starts_with("codestral") { + Ok(ModelCategory::Mistral) } else { unsupported_model!(s) } @@ -496,3 +519,10 @@ fn default_adc_file() -> Option { path.push("application_default_credentials.json"); Some(path) } + +fn strip_model_version(name: &str) -> &str { + match name.split_once('@') { + Some((v, _)) => v, + None => name, + } +}