From 615bab215bdc7ba156b90d5bbfdd720e5ce430a2 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Sun, 28 Apr 2024 10:55:41 +0800
Subject: [PATCH] feat: support vertexai claude (#439)

---
 config.example.yaml    |   3 +-
 models.yaml            |  12 +++++
 src/client/vertexai.rs | 110 ++++++++++++++++++++++++++++++++++-------
 3 files changed, 107 insertions(+), 18 deletions(-)

diff --git a/config.example.yaml b/config.example.yaml
index 9a30c3f..9706dd6 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -107,7 +107,8 @@ clients:
 
   # See https://cloud.google.com/vertex-ai
   - type: vertexai
-    api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
+    project_id: xxx
+    location: xxx
     # Specifies a application-default-credentials (adc) file, Optional field
     # Run `gcloud auth application-default login` to init the adc file
     # see https://cloud.google.com/docs/authentication/external/set-up-adc
diff --git a/models.yaml b/models.yaml
index 29417ab..4c80712 100644
--- a/models.yaml
+++ b/models.yaml
@@ -240,6 +240,18 @@
       input_price: 2.5
       output_price: 7.5
       supports_vision: true
+    - name: claude-3-opus@20240229
+      max_input_tokens: 200000
+      max_output_tokens: 4096
+      supports_vision: true
+    - name: claude-3-sonnet@20240229
+      max_input_tokens: 200000
+      max_output_tokens: 4096
+      supports_vision: true
+    - name: claude-3-haiku@20240307
+      max_input_tokens: 200000
+      max_output_tokens: 4096
+      supports_vision: true
 
 - type: ernie
   # docs:
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index afb1136..317ed2a 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -1,3 +1,4 @@
+use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
 use super::{
     catch_error, json_stream, message::*, patch_system_message, Client, ExtraConfig, Model,
     ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient,
@@ -11,14 +12,15 @@ use chrono::{Duration, Utc};
 use reqwest::{Client as ReqwestClient, RequestBuilder};
 use serde::Deserialize;
 use serde_json::{json, Value};
-use std::path::PathBuf;
+use std::{path::PathBuf, str::FromStr};
 
 static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
 
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct VertexAIConfig {
     pub name: Option<String>,
-    pub api_base: Option<String>,
+    pub project_id: Option<String>,
+    pub location: Option<String>,
     pub adc_file: Option<String>,
     pub block_threshold: Option<String>,
     #[serde(default)]
@@ -27,22 +29,28 @@ pub struct VertexAIConfig {
 }
 
 impl VertexAIClient {
-    config_get_fn!(api_base, get_api_base);
+    config_get_fn!(project_id, get_project_id);
+    config_get_fn!(location, get_location);
 
-    pub const PROMPTS: [PromptType<'static>; 1] =
-        [("api_base", "API Base:", true, PromptKind::String)];
+    pub const PROMPTS: [PromptType<'static>; 2] = [
+        ("project_id", "Project ID", true, PromptKind::String),
+        ("location", "Global Location", true, PromptKind::String),
+    ];
 
-    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
-        let api_base = self.get_api_base()?;
+    fn request_builder(
+        &self,
+        client: &ReqwestClient,
+        data: SendData,
+        model_category: &ModelCategory,
+    ) -> Result<RequestBuilder> {
+        let project_id = self.get_project_id()?;
+        let location = self.get_location()?;
 
-        let func = match data.stream {
-            true => "streamGenerateContent",
-            false => "generateContent",
-        };
-        let url = format!("{api_base}/{}:{}", &self.model.name, func);
+        let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); + let url = build_url(&base_url, &self.model.name, model_category, data.stream)?; let block_threshold = self.config.block_threshold.clone(); - let body = gemini_build_body(data, &self.model, block_threshold)?; + let body = build_body(data, &self.model, model_category, block_threshold)?; debug!("VertexAI Request: {url} {body}"); @@ -74,9 +82,13 @@ impl Client for VertexAIClient { client_common_fns!(); async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { + let model_category = ModelCategory::from_str(&self.model.name)?; self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - gemini_send_message(builder).await + let builder = self.request_builder(client, data, &model_category)?; + match model_category { + ModelCategory::Gemini => gemini_send_message(builder).await, + ModelCategory::Claude => claude_send_message(builder).await, + } } async fn send_message_streaming_inner( @@ -85,9 +97,13 @@ impl Client for VertexAIClient { handler: &mut ReplyHandler, data: SendData, ) -> Result<()> { + let model_category = ModelCategory::from_str(&self.model.name)?; self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - gemini_send_message_streaming(builder, handler).await + let builder = self.request_builder(client, data, &model_category)?; + match model_category { + ModelCategory::Gemini => gemini_send_message_streaming(builder, handler).await, + ModelCategory::Claude => claude_send_message_streaming(builder, handler).await, + } } } @@ -138,6 +154,46 @@ fn gemini_extract_text(data: &Value) -> Result<&str> { } } +fn build_url( + base_url: &str, + model_name: &str, + model_category: &ModelCategory, + stream: bool, +) -> Result { + let url = match model_category { + ModelCategory::Gemini => { + let func = match stream { + true => "streamGenerateContent", + false => "generateContent", + }; + format!("{base_url}/google/models/{model_name}:{func}") + } + ModelCategory::Claude => { + format!("{base_url}/anthropic/models/{model_name}:streamRawPredict") + } + }; + Ok(url) +} + +fn build_body( + data: SendData, + model: &Model, + model_category: &ModelCategory, + block_threshold: Option, +) -> Result { + match model_category { + ModelCategory::Gemini => gemini_build_body(data, model, block_threshold), + ModelCategory::Claude => { + let mut body = claude_build_body(data, model)?; + if let Some(body_obj) = body.as_object_mut() { + body_obj.remove("model"); + } + body["anthropic_version"] = "vertex-2023-10-16".into(); + Ok(body) + } + } +} + pub(crate) fn gemini_build_body( data: SendData, model: &Model, @@ -217,6 +273,26 @@ pub(crate) fn gemini_build_body( Ok(body) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelCategory { + Gemini, + Claude, +} + +impl FromStr for ModelCategory { + type Err = anyhow::Error; + + fn from_str(s: &str) -> std::result::Result { + if s.starts_with("gemini-") { + Ok(ModelCategory::Gemini) + } else if s.starts_with("claude-") { + Ok(ModelCategory::Claude) + } else { + unsupported_model!(s) + } + } +} + async fn fetch_access_token( client: &reqwest::Client, file: &Option,