feat: support vertexai claude (#439)

pull/450/head
sigoden 4 weeks ago committed by GitHub
parent d6df1e84a7
commit 615bab215b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -107,7 +107,8 @@ clients:
# See https://cloud.google.com/vertex-ai
- type: vertexai
api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
project_id: xxx
location: xxx
# Specifies a application-default-credentials (adc) file, Optional field
# Run `gcloud auth application-default login` to init the adc file
# see https://cloud.google.com/docs/authentication/external/set-up-adc

@ -240,6 +240,18 @@
input_price: 2.5
output_price: 7.5
supports_vision: true
- name: claude-3-opus@20240229
max_input_tokens: 200000
max_output_tokens: 4096
supports_vision: true
- name: claude-3-sonnet@20240229
max_input_tokens: 200000
max_output_tokens: 4096
supports_vision: true
- name: claude-3-haiku@20240307
max_input_tokens: 200000
max_output_tokens: 4096
supports_vision: true
- type: ernie
# docs:

@ -1,3 +1,4 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, ExtraConfig, Model,
ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient,
@ -11,14 +12,15 @@ use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::path::PathBuf;
use std::{path::PathBuf, str::FromStr};
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub project_id: Option<String>,
pub location: Option<String>,
pub adc_file: Option<String>,
pub block_threshold: Option<String>,
#[serde(default)]
@ -27,22 +29,28 @@ pub struct VertexAIConfig {
}
impl VertexAIClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(project_id, get_project_id);
config_get_fn!(location, get_location);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_base", "API Base:", true, PromptKind::String)];
pub const PROMPTS: [PromptType<'static>; 2] = [
("project_id", "Project ID", true, PromptKind::String),
("location", "Global Location", true, PromptKind::String),
];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
fn request_builder(
&self,
client: &ReqwestClient,
data: SendData,
model_category: &ModelCategory,
) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};
let url = format!("{api_base}/{}:{}", &self.model.name, func);
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
let url = build_url(&base_url, &self.model.name, model_category, data.stream)?;
let block_threshold = self.config.block_threshold.clone();
let body = gemini_build_body(data, &self.model, block_threshold)?;
let body = build_body(data, &self.model, model_category, block_threshold)?;
debug!("VertexAI Request: {url} {body}");
@ -74,9 +82,13 @@ impl Client for VertexAIClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
let builder = self.request_builder(client, data, &model_category)?;
match model_category {
ModelCategory::Gemini => gemini_send_message(builder).await,
ModelCategory::Claude => claude_send_message(builder).await,
}
}
async fn send_message_streaming_inner(
@ -85,9 +97,13 @@ impl Client for VertexAIClient {
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
let builder = self.request_builder(client, data, &model_category)?;
match model_category {
ModelCategory::Gemini => gemini_send_message_streaming(builder, handler).await,
ModelCategory::Claude => claude_send_message_streaming(builder, handler).await,
}
}
}
@ -138,6 +154,46 @@ fn gemini_extract_text(data: &Value) -> Result<&str> {
}
}
fn build_url(
base_url: &str,
model_name: &str,
model_category: &ModelCategory,
stream: bool,
) -> Result<String> {
let url = match model_category {
ModelCategory::Gemini => {
let func = match stream {
true => "streamGenerateContent",
false => "generateContent",
};
format!("{base_url}/google/models/{model_name}:{func}")
}
ModelCategory::Claude => {
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
}
};
Ok(url)
}
fn build_body(
data: SendData,
model: &Model,
model_category: &ModelCategory,
block_threshold: Option<String>,
) -> Result<Value> {
match model_category {
ModelCategory::Gemini => gemini_build_body(data, model, block_threshold),
ModelCategory::Claude => {
let mut body = claude_build_body(data, model)?;
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
}
body["anthropic_version"] = "vertex-2023-10-16".into();
Ok(body)
}
}
}
pub(crate) fn gemini_build_body(
data: SendData,
model: &Model,
@ -217,6 +273,26 @@ pub(crate) fn gemini_build_body(
Ok(body)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
Gemini,
Claude,
}
impl FromStr for ModelCategory {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
if s.starts_with("gemini-") {
Ok(ModelCategory::Gemini)
} else if s.starts_with("claude-") {
Ok(ModelCategory::Claude)
} else {
unsupported_model!(s)
}
}
}
async fn fetch_access_token(
client: &reqwest::Client,
file: &Option<String>,

Loading…
Cancel
Save