|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
|
|
|
|
|
use super::{
|
|
|
|
|
catch_error, json_stream, message::*, patch_system_message, Client, ExtraConfig, Model,
|
|
|
|
|
ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient,
|
|
|
|
@ -11,14 +12,15 @@ use chrono::{Duration, Utc};
|
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
|
use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
|
use std::path::PathBuf;
|
|
|
|
|
use std::{path::PathBuf, str::FromStr};
|
|
|
|
|
|
|
|
|
|
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize, Default)]
|
|
|
|
|
pub struct VertexAIConfig {
|
|
|
|
|
pub name: Option<String>,
|
|
|
|
|
pub api_base: Option<String>,
|
|
|
|
|
pub project_id: Option<String>,
|
|
|
|
|
pub location: Option<String>,
|
|
|
|
|
pub adc_file: Option<String>,
|
|
|
|
|
pub block_threshold: Option<String>,
|
|
|
|
|
#[serde(default)]
|
|
|
|
@ -27,22 +29,28 @@ pub struct VertexAIConfig {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl VertexAIClient {
|
|
|
|
|
config_get_fn!(api_base, get_api_base);
|
|
|
|
|
config_get_fn!(project_id, get_project_id);
|
|
|
|
|
config_get_fn!(location, get_location);
|
|
|
|
|
|
|
|
|
|
pub const PROMPTS: [PromptType<'static>; 1] =
|
|
|
|
|
[("api_base", "API Base:", true, PromptKind::String)];
|
|
|
|
|
pub const PROMPTS: [PromptType<'static>; 2] = [
|
|
|
|
|
("project_id", "Project ID", true, PromptKind::String),
|
|
|
|
|
("location", "Global Location", true, PromptKind::String),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
|
|
|
|
let api_base = self.get_api_base()?;
|
|
|
|
|
fn request_builder(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: SendData,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
let project_id = self.get_project_id()?;
|
|
|
|
|
let location = self.get_location()?;
|
|
|
|
|
|
|
|
|
|
let func = match data.stream {
|
|
|
|
|
true => "streamGenerateContent",
|
|
|
|
|
false => "generateContent",
|
|
|
|
|
};
|
|
|
|
|
let url = format!("{api_base}/{}:{}", &self.model.name, func);
|
|
|
|
|
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
|
|
|
|
|
let url = build_url(&base_url, &self.model.name, model_category, data.stream)?;
|
|
|
|
|
|
|
|
|
|
let block_threshold = self.config.block_threshold.clone();
|
|
|
|
|
let body = gemini_build_body(data, &self.model, block_threshold)?;
|
|
|
|
|
let body = build_body(data, &self.model, model_category, block_threshold)?;
|
|
|
|
|
|
|
|
|
|
debug!("VertexAI Request: {url} {body}");
|
|
|
|
|
|
|
|
|
@ -74,9 +82,13 @@ impl Client for VertexAIClient {
|
|
|
|
|
client_common_fns!();
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
|
|
|
let model_category = ModelCategory::from_str(&self.model.name)?;
|
|
|
|
|
self.prepare_access_token().await?;
|
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
|
gemini_send_message(builder).await
|
|
|
|
|
let builder = self.request_builder(client, data, &model_category)?;
|
|
|
|
|
match model_category {
|
|
|
|
|
ModelCategory::Gemini => gemini_send_message(builder).await,
|
|
|
|
|
ModelCategory::Claude => claude_send_message(builder).await,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming_inner(
|
|
|
|
@ -85,9 +97,13 @@ impl Client for VertexAIClient {
|
|
|
|
|
handler: &mut ReplyHandler,
|
|
|
|
|
data: SendData,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let model_category = ModelCategory::from_str(&self.model.name)?;
|
|
|
|
|
self.prepare_access_token().await?;
|
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
|
gemini_send_message_streaming(builder, handler).await
|
|
|
|
|
let builder = self.request_builder(client, data, &model_category)?;
|
|
|
|
|
match model_category {
|
|
|
|
|
ModelCategory::Gemini => gemini_send_message_streaming(builder, handler).await,
|
|
|
|
|
ModelCategory::Claude => claude_send_message_streaming(builder, handler).await,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -138,6 +154,46 @@ fn gemini_extract_text(data: &Value) -> Result<&str> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_url(
|
|
|
|
|
base_url: &str,
|
|
|
|
|
model_name: &str,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
stream: bool,
|
|
|
|
|
) -> Result<String> {
|
|
|
|
|
let url = match model_category {
|
|
|
|
|
ModelCategory::Gemini => {
|
|
|
|
|
let func = match stream {
|
|
|
|
|
true => "streamGenerateContent",
|
|
|
|
|
false => "generateContent",
|
|
|
|
|
};
|
|
|
|
|
format!("{base_url}/google/models/{model_name}:{func}")
|
|
|
|
|
}
|
|
|
|
|
ModelCategory::Claude => {
|
|
|
|
|
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
Ok(url)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_body(
|
|
|
|
|
data: SendData,
|
|
|
|
|
model: &Model,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
block_threshold: Option<String>,
|
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
match model_category {
|
|
|
|
|
ModelCategory::Gemini => gemini_build_body(data, model, block_threshold),
|
|
|
|
|
ModelCategory::Claude => {
|
|
|
|
|
let mut body = claude_build_body(data, model)?;
|
|
|
|
|
if let Some(body_obj) = body.as_object_mut() {
|
|
|
|
|
body_obj.remove("model");
|
|
|
|
|
}
|
|
|
|
|
body["anthropic_version"] = "vertex-2023-10-16".into();
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) fn gemini_build_body(
|
|
|
|
|
data: SendData,
|
|
|
|
|
model: &Model,
|
|
|
|
@ -217,6 +273,26 @@ pub(crate) fn gemini_build_body(
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
|
|
|
enum ModelCategory {
|
|
|
|
|
Gemini,
|
|
|
|
|
Claude,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl FromStr for ModelCategory {
|
|
|
|
|
type Err = anyhow::Error;
|
|
|
|
|
|
|
|
|
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
|
|
|
|
if s.starts_with("gemini-") {
|
|
|
|
|
Ok(ModelCategory::Gemini)
|
|
|
|
|
} else if s.starts_with("claude-") {
|
|
|
|
|
Ok(ModelCategory::Claude)
|
|
|
|
|
} else {
|
|
|
|
|
unsupported_model!(s)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn fetch_access_token(
|
|
|
|
|
client: &reqwest::Client,
|
|
|
|
|
file: &Option<String>,
|
|
|
|
|