mirror of https://github.com/sigoden/aichat
feat: support vertexai (#308)
parent
3bf0c371e4
commit
5e4210980d
@ -0,0 +1,92 @@
|
||||
use super::{
|
||||
Client, ExtraConfig, VertexAIClient, Model, PromptType,
|
||||
SendData, TokensCountFactors,
|
||||
};
|
||||
use super::gemini::{build_body, send_message, send_message_streaming};
|
||||
|
||||
use crate::{render::ReplyHandler, utils::PromptKind};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
const MODELS: [(&str, usize, &str); 2] = [
|
||||
("gemini-pro", 32760, "text"),
|
||||
("gemini-pro-vision", 16384, "text,vision"),
|
||||
];
|
||||
|
||||
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct VertexAIConfig {
|
||||
pub name: Option<String>,
|
||||
pub api_base: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for VertexAIClient {
|
||||
client_common_fns!();
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl VertexAIClient {
|
||||
config_get_fn!(api_base, get_api_base);
|
||||
config_get_fn!(api_key, get_api_key);
|
||||
|
||||
pub const PROMPTS: [PromptType<'static>; 2] = [
|
||||
("api_base", "API Base:", true, PromptKind::String),
|
||||
("api_key", "API Key:", true, PromptKind::String),
|
||||
];
|
||||
|
||||
pub fn list_models(local_config: &VertexAIConfig) -> Vec<Model> {
|
||||
let client_name = Self::name(local_config);
|
||||
MODELS
|
||||
.into_iter()
|
||||
.map(|(name, max_tokens, capabilities)| {
|
||||
Model::new(client_name, name)
|
||||
.set_capabilities(capabilities.into())
|
||||
.set_max_tokens(Some(max_tokens))
|
||||
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_base = self.get_api_base()?;
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let func = match data.stream {
|
||||
true => "streamGenerateContent",
|
||||
false => "generateContent",
|
||||
};
|
||||
|
||||
let body = build_body(data, self.model.name.clone())?;
|
||||
|
||||
let model = self.model.name.clone();
|
||||
|
||||
let url = format!("{api_base}/{}:{}", model, func);
|
||||
|
||||
debug!("VertexAI Request: {url} {body}");
|
||||
|
||||
let builder = client.post(url).bearer_auth(api_key).json(&body);
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue