|
|
@ -1,3 +1,4 @@
|
|
|
|
|
|
|
|
use super::access_token::*;
|
|
|
|
use super::{
|
|
|
|
use super::{
|
|
|
|
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
|
|
|
|
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
|
|
|
|
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
|
|
|
|
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
|
|
|
@ -12,8 +13,6 @@ use serde::Deserialize;
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
use std::path::PathBuf;
|
|
|
|
use std::path::PathBuf;
|
|
|
|
|
|
|
|
|
|
|
|
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize, Default)]
|
|
|
|
#[derive(Debug, Clone, Deserialize, Default)]
|
|
|
|
pub struct VertexAIConfig {
|
|
|
|
pub struct VertexAIConfig {
|
|
|
|
pub name: Option<String>,
|
|
|
|
pub name: Option<String>,
|
|
|
@ -39,6 +38,7 @@ impl VertexAIClient {
|
|
|
|
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
|
|
|
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
|
|
|
let project_id = self.get_project_id()?;
|
|
|
|
let project_id = self.get_project_id()?;
|
|
|
|
let location = self.get_location()?;
|
|
|
|
let location = self.get_location()?;
|
|
|
|
|
|
|
|
let access_token = get_access_token(self.name())?;
|
|
|
|
|
|
|
|
|
|
|
|
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
|
|
|
|
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
|
|
|
|
|
|
|
|
|
|
|
@ -52,10 +52,7 @@ impl VertexAIClient {
|
|
|
|
|
|
|
|
|
|
|
|
debug!("VertexAI Request: {url} {body}");
|
|
|
|
debug!("VertexAI Request: {url} {body}");
|
|
|
|
|
|
|
|
|
|
|
|
let builder = client
|
|
|
|
let builder = client.post(url).bearer_auth(access_token).json(&body);
|
|
|
|
.post(url)
|
|
|
|
|
|
|
|
.bearer_auth(unsafe { &ACCESS_TOKEN.0 })
|
|
|
|
|
|
|
|
.json(&body);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
Ok(builder)
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -70,7 +67,7 @@ impl Client for VertexAIClient {
|
|
|
|
client: &ReqwestClient,
|
|
|
|
client: &ReqwestClient,
|
|
|
|
data: SendData,
|
|
|
|
data: SendData,
|
|
|
|
) -> Result<(String, CompletionDetails)> {
|
|
|
|
) -> Result<(String, CompletionDetails)> {
|
|
|
|
prepare_access_token(client, &self.config.adc_file).await?;
|
|
|
|
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
gemini_send_message(builder).await
|
|
|
|
gemini_send_message(builder).await
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -81,7 +78,7 @@ impl Client for VertexAIClient {
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
data: SendData,
|
|
|
|
data: SendData,
|
|
|
|
) -> Result<()> {
|
|
|
|
) -> Result<()> {
|
|
|
|
prepare_access_token(client, &self.config.adc_file).await?;
|
|
|
|
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
gemini_send_message_streaming(builder, handler).await
|
|
|
|
gemini_send_message_streaming(builder, handler).await
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -217,20 +214,24 @@ pub(crate) fn gemini_build_body(
|
|
|
|
Ok(body)
|
|
|
|
Ok(body)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
|
|
|
|
pub async fn prepare_gcloud_access_token(
|
|
|
|
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
|
|
|
|
client: &reqwest::Client,
|
|
|
|
let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
|
|
|
|
client_name: &str,
|
|
|
|
|
|
|
|
adc_file: &Option<String>,
|
|
|
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
|
|
|
if !is_valid_access_token(client_name) {
|
|
|
|
|
|
|
|
let (token, expires_in) = fetch_access_token(client, adc_file)
|
|
|
|
.await
|
|
|
|
.await
|
|
|
|
.with_context(|| "Failed to fetch access token")?;
|
|
|
|
.with_context(|| "Failed to fetch access token")?;
|
|
|
|
let expires_at = Utc::now()
|
|
|
|
let expires_at = Utc::now()
|
|
|
|
+ Duration::try_seconds(expires_in)
|
|
|
|
+ Duration::try_seconds(expires_in)
|
|
|
|
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
|
|
|
|
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
|
|
|
|
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
|
|
|
|
set_access_token(client_name, token, expires_at.timestamp())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(())
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub async fn fetch_gcloud_access_token(
|
|
|
|
async fn fetch_access_token(
|
|
|
|
client: &reqwest::Client,
|
|
|
|
client: &reqwest::Client,
|
|
|
|
file: &Option<String>,
|
|
|
|
file: &Option<String>,
|
|
|
|
) -> Result<(String, i64)> {
|
|
|
|
) -> Result<(String, i64)> {
|
|
|
|