From 9b283024b47f57ad7bbf032fa015eab42f8162a9 Mon Sep 17 00:00:00 2001 From: sigoden Date: Mon, 6 May 2024 08:19:42 +0800 Subject: [PATCH] feat: extract vertexai-claude client (#485) --- Argcfile.sh | 4 +- config.example.yaml | 31 +++++++- models.yaml | 9 ++- src/client/gemini.rs | 7 +- src/client/mod.rs | 6 ++ src/client/vertexai.rs | 139 ++++++++-------------------------- src/client/vertexai_claude.rs | 100 ++++++++++++++++++++++++ 7 files changed, 178 insertions(+), 118 deletions(-) create mode 100644 src/client/vertexai_claude.rs diff --git a/Argcfile.sh b/Argcfile.sh index e187489..4a2bb89 100755 --- a/Argcfile.sh +++ b/Argcfile.sh @@ -226,14 +226,14 @@ chat-ollama() { }' } -# @cmd Chat with vertexai-gemini api +# @cmd Chat with vertexai api # @env require-tools gcloud # @env VERTEXAI_PROJECT_ID! # @env VERTEXAI_LOCATION! # @option -m --model=gemini-1.0-pro $VERTEXAI_GEMINI_MODEL # @flag -S --no-stream # @arg text~ -chat-vertexai-gemini() { +chat-vertexai() { api_key="$(gcloud auth print-access-token)" func="streamGenerateContent" if [[ -n "$argc_no_stream" ]]; then diff --git a/config.example.yaml b/config.example.yaml index 8a4b227..ce16df8 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -60,8 +60,15 @@ clients: # See https://ai.google.dev/docs - type: gemini api_key: xxx # ENV: {client}_API_KEY - # possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE - block_threshold: BLOCK_NONE # Optional + safetySettings: + - category: HARM_CATEGORY_HARASSMENT + threshold: BLOCK_NONE + - category: HARM_CATEGORY_HATE_SPEECH + threshold: BLOCK_NONE + - category: HARM_CATEGORY_SEXUALLY_EXPLICIT + threshold: BLOCK_NONE + - category: HARM_CATEGORY_DANGEROUS_CONTENT + threshold: BLOCK_NONE # See https://docs.anthropic.com/claude/reference/getting-started-with-the-api - type: claude @@ -114,8 +121,24 @@ clients: # Run `gcloud auth application-default login` to init the adc file # see https://cloud.google.com/docs/authentication/external/set-up-adc adc_file: - # Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE - block_threshold: BLOCK_ONLY_HIGH + safetySettings: + - category: HARM_CATEGORY_HARASSMENT + threshold: BLOCK_ONLY_HIGH + - category: HARM_CATEGORY_HATE_SPEECH + threshold: BLOCK_ONLY_HIGH + - category: HARM_CATEGORY_SEXUALLY_EXPLICIT + threshold: BLOCK_ONLY_HIGH + - category: HARM_CATEGORY_DANGEROUS_CONTENT + threshold: BLOCK_ONLY_HIGH + + # See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude + - type: vertexai-claude + project_id: xxx # ENV: {client}_PROJECT_ID + location: xxx # ENV: {client}_LOCATION + # Specifies a application-default-credentials (adc) file, Optional field + # Run `gcloud auth application-default login` to init the adc file + # see https://cloud.google.com/docs/authentication/external/set-up-adc + adc_file: # See https://docs.aws.amazon.com/bedrock/latest/userguide/ - type: bedrock diff --git a/models.yaml b/models.yaml index 5b57dae..aa7fd45 100644 --- a/models.yaml +++ b/models.yaml @@ -238,7 +238,6 @@ # - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini # notes: # - get max_output_tokens info from models doc - # - claude models have not been tested models: - name: gemini-1.0-pro max_input_tokens: 24568 @@ -257,6 +256,14 @@ input_price: 2.5 output_price: 7.5 supports_vision: true + +- platform: vertexai-claude + # docs: + # - https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude + # notes: + # - get max_output_tokens info from models doc + # - claude models have not been tested + models: - name: claude-3-opus@20240229 max_input_tokens: 200000 max_output_tokens: 4096 diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 8f6a76d..28fc93e 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -11,7 +11,8 @@ const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/ pub struct GeminiConfig { pub name: Option, pub api_key: Option, - pub block_threshold: Option, + #[serde(rename = "safetySettings")] + pub safety_settings: Option, #[serde(default)] pub models: Vec, pub extra: Option, @@ -31,9 +32,7 @@ impl GeminiClient { false => "generateContent", }; - let block_threshold = self.config.block_threshold.clone(); - - let body = gemini_build_body(data, &self.model, block_threshold)?; + let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?; let model = &self.model.name; diff --git a/src/client/mod.rs b/src/client/mod.rs index 4ea5533..43c8c4b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -31,6 +31,12 @@ register_client!( AzureOpenAIClient ), (vertexai, "vertexai", VertexAIConfig, VertexAIClient), + ( + vertexai_claude, + "vertexai-claude", + VertexAIClaudeConfig, + VertexAIClaudeClient + ), (bedrock, "bedrock", BedrockConfig, BedrockClient), (cloudflare, "cloudflare", CloudflareConfig, CloudflareClient), (replicate, "replicate", ReplicateConfig, ReplicateClient), diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 2c1edd1..db04f38 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,4 +1,3 @@ -use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming}; use super::{ catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler, @@ -11,7 +10,7 @@ use chrono::{Duration, Utc}; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; -use std::{path::PathBuf, str::FromStr}; +use std::path::PathBuf; static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation @@ -21,7 +20,8 @@ pub struct VertexAIConfig { pub project_id: Option, pub location: Option, pub adc_file: Option, - pub block_threshold: Option, + #[serde(rename = "safetySettings")] + pub safety_settings: Option, #[serde(default)] pub models: Vec, pub extra: Option, @@ -36,20 +36,19 @@ impl VertexAIClient { ("location", "Location", true, PromptKind::String), ]; - fn request_builder( - &self, - client: &ReqwestClient, - data: SendData, - model_category: &ModelCategory, - ) -> Result { + fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let project_id = self.get_project_id()?; let location = self.get_location()?; let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); - let url = build_url(&base_url, &self.model.name, model_category, data.stream)?; - let block_threshold = self.config.block_threshold.clone(); - let body = build_body(data, &self.model, model_category, block_threshold)?; + let func = match data.stream { + true => "streamGenerateContent", + false => "generateContent", + }; + let url = format!("{base_url}/google/models/{}:{func}", self.model.name); + + let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?; debug!("VertexAI Request: {url} {body}"); @@ -60,20 +59,6 @@ impl VertexAIClient { Ok(builder) } - - async fn prepare_access_token(&self) -> Result<()> { - if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { - let client = self.build_client()?; - let (token, expires_in) = fetch_access_token(&client, &self.config.adc_file) - .await - .with_context(|| "Failed to fetch access token")?; - let expires_at = Utc::now() - + Duration::try_seconds(expires_in) - .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; - unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) }; - } - Ok(()) - } } #[async_trait] @@ -85,13 +70,9 @@ impl Client for VertexAIClient { client: &ReqwestClient, data: SendData, ) -> Result<(String, CompletionDetails)> { - let model_category = ModelCategory::from_str(&self.model.name)?; - self.prepare_access_token().await?; - let builder = self.request_builder(client, data, &model_category)?; - match model_category { - ModelCategory::Gemini => gemini_send_message(builder).await, - ModelCategory::Claude => claude_send_message(builder).await, - } + prepare_access_token(client, &self.config.adc_file).await?; + let builder = self.request_builder(client, data)?; + gemini_send_message(builder).await } async fn send_message_streaming_inner( @@ -100,13 +81,9 @@ impl Client for VertexAIClient { handler: &mut SseHandler, data: SendData, ) -> Result<()> { - let model_category = ModelCategory::from_str(&self.model.name)?; - self.prepare_access_token().await?; - let builder = self.request_builder(client, data, &model_category)?; - match model_category { - ModelCategory::Gemini => gemini_send_message_streaming(builder, handler).await, - ModelCategory::Claude => claude_send_message_streaming(builder, handler).await, - } + prepare_access_token(client, &self.config.adc_file).await?; + let builder = self.request_builder(client, data)?; + gemini_send_message_streaming(builder, handler).await } } @@ -158,7 +135,7 @@ fn gemini_extract_text(data: &Value) -> Result<&str> { .as_str() .or_else(|| data["candidates"][0]["finishReason"].as_str()) { - bail!("Blocked by safety settings,consider adjusting `block_threshold` in the client configuration") + bail!("Blocked by safety settings,consider adjusting `safetySettings` in the client configuration") } else { bail!("Invalid response data: {data}") } @@ -166,50 +143,10 @@ fn gemini_extract_text(data: &Value) -> Result<&str> { } } -fn build_url( - base_url: &str, - model_name: &str, - model_category: &ModelCategory, - stream: bool, -) -> Result { - let url = match model_category { - ModelCategory::Gemini => { - let func = match stream { - true => "streamGenerateContent", - false => "generateContent", - }; - format!("{base_url}/google/models/{model_name}:{func}") - } - ModelCategory::Claude => { - format!("{base_url}/anthropic/models/{model_name}:streamRawPredict") - } - }; - Ok(url) -} - -fn build_body( - data: SendData, - model: &Model, - model_category: &ModelCategory, - block_threshold: Option, -) -> Result { - match model_category { - ModelCategory::Gemini => gemini_build_body(data, model, block_threshold), - ModelCategory::Claude => { - let mut body = claude_build_body(data, model)?; - if let Some(body_obj) = body.as_object_mut() { - body_obj.remove("model"); - } - body["anthropic_version"] = "vertex-2023-10-16".into(); - Ok(body) - } - } -} - pub(crate) fn gemini_build_body( data: SendData, model: &Model, - block_threshold: Option, + safety_settings: Option, ) -> Result { let SendData { mut messages, @@ -263,13 +200,8 @@ pub(crate) fn gemini_build_body( let mut body = json!({ "contents": contents, "generationConfig": {} }); - if let Some(block_threshold) = block_threshold { - body["safetySettings"] = json!([ - {"category":"HARM_CATEGORY_HARASSMENT","threshold":block_threshold}, - {"category":"HARM_CATEGORY_HATE_SPEECH","threshold":block_threshold}, - {"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":block_threshold}, - {"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":block_threshold} - ]); + if let Some(safety_settings) = safety_settings { + body["safetySettings"] = safety_settings; } if let Some(v) = model.max_output_tokens { @@ -285,27 +217,20 @@ pub(crate) fn gemini_build_body( Ok(body) } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ModelCategory { - Gemini, - Claude, -} - -impl FromStr for ModelCategory { - type Err = anyhow::Error; - - fn from_str(s: &str) -> std::result::Result { - if s.starts_with("gemini-") { - Ok(ModelCategory::Gemini) - } else if s.starts_with("claude-") { - Ok(ModelCategory::Claude) - } else { - unsupported_model!(s) - } +async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option) -> Result<()> { + if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { + let (token, expires_in) = fetch_gcloud_access_token(client, adc_file) + .await + .with_context(|| "Failed to fetch access token")?; + let expires_at = Utc::now() + + Duration::try_seconds(expires_in) + .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; + unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) }; } + Ok(()) } -async fn fetch_access_token( +pub async fn fetch_gcloud_access_token( client: &reqwest::Client, file: &Option, ) -> Result<(String, i64)> { diff --git a/src/client/vertexai_claude.rs b/src/client/vertexai_claude.rs new file mode 100644 index 0000000..78adb72 --- /dev/null +++ b/src/client/vertexai_claude.rs @@ -0,0 +1,100 @@ +use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming}; +use super::vertexai::fetch_gcloud_access_token; +use super::{ + Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, + SseHandler, VertexAIClaudeClient, +}; + +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use chrono::{Duration, Utc}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; +use serde::Deserialize; + +static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct VertexAIClaudeConfig { + pub name: Option, + pub project_id: Option, + pub location: Option, + pub adc_file: Option, + #[serde(default)] + pub models: Vec, + pub extra: Option, +} + +impl VertexAIClaudeClient { + config_get_fn!(project_id, get_project_id); + config_get_fn!(location, get_location); + + pub const PROMPTS: [PromptAction<'static>; 2] = [ + ("project_id", "Project ID", true, PromptKind::String), + ("location", "Location", true, PromptKind::String), + ]; + + fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + let project_id = self.get_project_id()?; + let location = self.get_location()?; + + let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); + let url = format!( + "{base_url}/anthropic/models/{}:streamRawPredict", + self.model.name + ); + + let mut body = claude_build_body(data, &self.model)?; + if let Some(body_obj) = body.as_object_mut() { + body_obj.remove("model"); + } + body["anthropic_version"] = "vertex-2023-10-16".into(); + + debug!("VertexAIClaude Request: {url} {body}"); + + let builder = client + .post(url) + .bearer_auth(unsafe { &ACCESS_TOKEN.0 }) + .json(&body); + + Ok(builder) + } +} + +#[async_trait] +impl Client for VertexAIClaudeClient { + client_common_fns!(); + + async fn send_message_inner( + &self, + client: &ReqwestClient, + data: SendData, + ) -> Result<(String, CompletionDetails)> { + prepare_access_token(client, &self.config.adc_file).await?; + let builder = self.request_builder(client, data)?; + claude_send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: SendData, + ) -> Result<()> { + prepare_access_token(client, &self.config.adc_file).await?; + let builder = self.request_builder(client, data)?; + claude_send_message_streaming(builder, handler).await + } +} + +async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option) -> Result<()> { + if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { + let (token, expires_in) = fetch_gcloud_access_token(client, adc_file) + .await + .with_context(|| "Failed to fetch access token")?; + let expires_at = Utc::now() + + Duration::try_seconds(expires_in) + .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; + unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) }; + } + Ok(()) +}