feat: support vertexai (#308)

pull/309/head
sigoden 4 months ago committed by GitHub
parent 3bf0c371e4
commit 5e4210980d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -56,7 +56,7 @@ clients:
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_base: https://{RESOURCE}.openai.azure.com
api_key: xxx
models:
- name: MyGPT4 # Model deployment name
@ -69,4 +69,9 @@ clients:
# See https://help.aliyun.com/zh/dashscope/
- type: qianwen
api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# See https://cloud.google.com/vertex-ai
- type: vertexai
api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
api_key: xxx

@ -89,7 +89,7 @@ impl GeminiClient {
}
}
async fn send_message(builder: RequestBuilder) -> Result<String> {
pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -102,7 +102,7 @@ async fn send_message(builder: RequestBuilder) -> Result<String> {
Ok(output.to_string())
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
pub(crate) async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
let res = builder.send().await?;
if res.status() != 200 {
let data: Value = res.json().await?;
@ -178,7 +178,7 @@ fn check_error(data: &Value) -> Result<()> {
}
}
fn build_body(data: SendData, _model: String) -> Result<Value> {
pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
let SendData {
mut messages,
temperature,

@ -20,4 +20,5 @@ register_client!(
),
(ernie, "ernie", ErnieConfig, ErnieClient),
(qianwen, "qianwen", QianwenConfig, QianwenClient),
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
);

@ -0,0 +1,92 @@
use super::{
Client, ExtraConfig, VertexAIClient, Model, PromptType,
SendData, TokensCountFactors,
};
use super::gemini::{build_body, send_message, send_message_streaming};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
// Supported Vertex AI Gemini models as (model name, max input tokens, capabilities).
// Capabilities is a comma-separated list parsed elsewhere via `.into()` on set_capabilities.
const MODELS: [(&str, usize, &str); 2] = [
("gemini-pro", 32760, "text"),
("gemini-pro-vision", 16384, "text,vision"),
];
// (per-message overhead, bias) factors used when estimating token counts
// for this client; exact semantics defined by `TokensCountFactors` — see its declaration.
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
/// User-facing configuration for the Vertex AI client, deserialized from the
/// `clients:` section of the config file. All fields are optional; missing
/// values are resolved (or prompted for) via the `config_get_fn!` accessors.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
// Optional display name for this client instance.
pub name: Option<String>,
// Base URL up to `.../publishers/google/models`; the model name and
// method (`:generateContent`) are appended at request time.
pub api_base: Option<String>,
// Credential sent as a Bearer token. NOTE(review): Vertex AI normally
// expects an OAuth2 access token here rather than a static key — confirm.
pub api_key: Option<String>,
// Client-wide extra options shared across client types.
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for VertexAIClient {
    client_common_fns!();

    /// Performs a single non-streaming chat request and returns the reply text.
    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
        send_message(self.request_builder(client, data)?).await
    }

    /// Performs a streaming chat request, forwarding chunks into `handler`.
    async fn send_message_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut ReplyHandler,
        data: SendData,
    ) -> Result<()> {
        let request = self.request_builder(client, data)?;
        send_message_streaming(request, handler).await
    }
}
impl VertexAIClient {
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// Interactive setup prompts: (config key, label, required, input kind).
    pub const PROMPTS: [PromptType<'static>; 2] = [
        ("api_base", "API Base:", true, PromptKind::String),
        ("api_key", "API Key:", true, PromptKind::String),
    ];

    /// Expands the static `MODELS` table into `Model` entries bound to this
    /// client instance's name.
    pub fn list_models(local_config: &VertexAIConfig) -> Vec<Model> {
        let client_name = Self::name(local_config);
        MODELS
            .into_iter()
            .map(|(name, max_tokens, capabilities)| {
                Model::new(client_name, name)
                    .set_capabilities(capabilities.into())
                    .set_max_tokens(Some(max_tokens))
                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
            })
            .collect()
    }

    /// Assembles the authenticated POST request for a chat call.
    ///
    /// The URL is `{api_base}/{model}:{generateContent|streamGenerateContent}`;
    /// the JSON body is built by the shared Gemini `build_body`.
    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
        let api_base = self.get_api_base()?;
        let api_key = self.get_api_key()?;

        // Choose the endpoint method before `data` is consumed by build_body.
        // (`if` over `match` on a bool — clippy::match_bool.)
        let func = if data.stream {
            "streamGenerateContent"
        } else {
            "generateContent"
        };

        // Borrow the model name for the URL; only one clone is needed, for
        // build_body's owned-String parameter.
        let url = format!("{api_base}/{}:{func}", self.model.name);
        let body = build_body(data, self.model.name.clone())?;

        debug!("VertexAI Request: {url} {body}");

        let builder = client.post(url).bearer_auth(api_key).json(&body);
        Ok(builder)
    }
}
Loading…
Cancel
Save