diff --git a/src/client/access_token.rs b/src/client/access_token.rs new file mode 100644 index 0000000..07640d5 --- /dev/null +++ b/src/client/access_token.rs @@ -0,0 +1,34 @@ +use anyhow::{anyhow, Result}; +use chrono::Utc; +use indexmap::IndexMap; +use lazy_static::lazy_static; +use parking_lot::RwLock; + +lazy_static! { + static ref ACCESS_TOKENS: RwLock> = + RwLock::new(IndexMap::new()); +} + +pub fn get_access_token(client_name: &str) -> Result { + ACCESS_TOKENS + .read() + .get(client_name) + .map(|(token, _)| token.clone()) + .ok_or_else(|| anyhow!("Invalid access token")) +} + +pub fn is_valid_access_token(client_name: &str) -> bool { + let access_tokens = ACCESS_TOKENS.read(); + let (token, expires_at) = match access_tokens.get(client_name) { + Some(v) => v, + None => return false, + }; + !token.is_empty() && Utc::now().timestamp() < *expires_at +} + +pub fn set_access_token(client_name: &str, token: String, expires_at: i64) { + let mut access_tokens = ACCESS_TOKENS.write(); + let entry = access_tokens.entry(client_name.to_string()).or_default(); + entry.0 = token; + entry.1 = expires_at; +} diff --git a/src/client/common.rs b/src/client/common.rs index fb0c38e..24e6f8e 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -102,8 +102,8 @@ macro_rules! register_client { } } - pub fn name(config: &$config) -> &str { - config.name.as_deref().unwrap_or(Self::NAME) + pub fn name(local_config: &$config) -> &str { + local_config.name.as_deref().unwrap_or(Self::NAME) } } @@ -184,6 +184,10 @@ macro_rules! client_common_fns { Self::list_models(&self.config) } + fn name(&self) -> &str { + Self::name(&self.config) + } + fn model(&self) -> &Model { &self.model } @@ -259,6 +263,8 @@ pub trait Client: Sync + Send { fn list_models(&self) -> Vec; + fn name(&self) -> &str; + fn model(&self) -> &Model; fn model_mut(&mut self) -> &mut Model; diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 982edae..d3002ff 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,3 +1,4 @@ +use super::access_token::*; use super::{ maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, @@ -5,7 +6,6 @@ use super::{ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; -use chrono::Utc; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -14,8 +14,6 @@ use std::env; const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1"; const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token"; -static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); - #[derive(Debug, Clone, Deserialize, Default)] pub struct ErnieConfig { pub name: Option, @@ -34,11 +32,11 @@ impl ErnieClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let body = build_body(data, &self.model); + let access_token = get_access_token(self.name())?; let url = format!( - "{API_BASE}/wenxinworkshop/chat/{}?access_token={}", + "{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}", &self.model.name, - unsafe { &ACCESS_TOKEN.0 } ); debug!("Ernie Request: {url} {body}"); @@ -49,7 +47,8 @@ impl ErnieClient { } async fn prepare_access_token(&self) -> Result<()> { - if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { + let client_name = self.name(); + if !is_valid_access_token(client_name) { let env_prefix = Self::name(&self.config).to_uppercase(); let api_key = self.config.api_key.clone(); let api_key = api_key @@ -65,7 +64,7 @@ impl ErnieClient { let token = fetch_access_token(&client, &api_key, &secret_key) .await .with_context(|| "Failed to fetch access token")?; - unsafe { ACCESS_TOKEN = (token, 86400) }; + set_access_token(client_name, token, 86400); } Ok(()) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 43c8c4b..07f5f57 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,5 +1,6 @@ #[macro_use] mod common; +mod access_token; mod message; mod model; mod prompt_format; diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index db04f38..9907a7f 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,3 +1,4 @@ +use super::access_token::*; use super::{ catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler, @@ -12,8 +13,6 @@ use serde::Deserialize; use serde_json::{json, Value}; use std::path::PathBuf; -static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation - #[derive(Debug, Clone, Deserialize, Default)] pub struct VertexAIConfig { pub name: Option, @@ -39,6 +38,7 @@ impl VertexAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let project_id = self.get_project_id()?; let location = self.get_location()?; + let access_token = get_access_token(self.name())?; let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); @@ -52,10 +52,7 @@ impl VertexAIClient { debug!("VertexAI Request: {url} {body}"); - let builder = client - .post(url) - .bearer_auth(unsafe { &ACCESS_TOKEN.0 }) - .json(&body); + let builder = client.post(url).bearer_auth(access_token).json(&body); Ok(builder) } @@ -70,7 +67,7 @@ impl Client for VertexAIClient { client: &ReqwestClient, data: SendData, ) -> Result<(String, CompletionDetails)> { - prepare_access_token(client, &self.config.adc_file).await?; + prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; gemini_send_message(builder).await } @@ -81,7 +78,7 @@ impl Client for VertexAIClient { handler: &mut SseHandler, data: SendData, ) -> Result<()> { - prepare_access_token(client, &self.config.adc_file).await?; + prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; gemini_send_message_streaming(builder, handler).await } @@ -217,20 +214,24 @@ pub(crate) fn gemini_build_body( Ok(body) } -async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option) -> Result<()> { - if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { - let (token, expires_in) = fetch_gcloud_access_token(client, adc_file) +pub async fn prepare_gcloud_access_token( + client: &reqwest::Client, + client_name: &str, + adc_file: &Option, +) -> Result<()> { + if !is_valid_access_token(client_name) { + let (token, expires_in) = fetch_access_token(client, adc_file) .await .with_context(|| "Failed to fetch access token")?; let expires_at = Utc::now() + Duration::try_seconds(expires_in) .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; - unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) }; + set_access_token(client_name, token, expires_at.timestamp()) } Ok(()) } -pub async fn fetch_gcloud_access_token( +async fn fetch_access_token( client: &reqwest::Client, file: &Option, ) -> Result<(String, i64)> { diff --git a/src/client/vertexai_claude.rs b/src/client/vertexai_claude.rs index 78adb72..2b5f763 100644 --- a/src/client/vertexai_claude.rs +++ b/src/client/vertexai_claude.rs @@ -1,18 +1,16 @@ +use super::access_token::*; use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming}; -use super::vertexai::fetch_gcloud_access_token; +use super::vertexai::prepare_gcloud_access_token; use super::{ Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler, VertexAIClaudeClient, }; -use anyhow::{anyhow, Context, Result}; +use anyhow::Result; use async_trait::async_trait; -use chrono::{Duration, Utc}; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; -static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation - #[derive(Debug, Clone, Deserialize, Default)] pub struct VertexAIClaudeConfig { pub name: Option, @@ -36,6 +34,7 @@ impl VertexAIClaudeClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let project_id = self.get_project_id()?; let location = self.get_location()?; + let access_token = get_access_token(self.name())?; let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); let url = format!( @@ -51,10 +50,7 @@ impl VertexAIClaudeClient { debug!("VertexAIClaude Request: {url} {body}"); - let builder = client - .post(url) - .bearer_auth(unsafe { &ACCESS_TOKEN.0 }) - .json(&body); + let builder = client.post(url).bearer_auth(access_token).json(&body); Ok(builder) } @@ -69,7 +65,7 @@ impl Client for VertexAIClaudeClient { client: &ReqwestClient, data: SendData, ) -> Result<(String, CompletionDetails)> { - prepare_access_token(client, &self.config.adc_file).await?; + prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; claude_send_message(builder).await } @@ -80,21 +76,8 @@ impl Client for VertexAIClaudeClient { handler: &mut SseHandler, data: SendData, ) -> Result<()> { - prepare_access_token(client, &self.config.adc_file).await?; + prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; claude_send_message_streaming(builder, handler).await } } - -async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option) -> Result<()> { - if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } { - let (token, expires_in) = fetch_gcloud_access_token(client, adc_file) - .await - .with_context(|| "Failed to fetch access token")?; - let expires_at = Utc::now() - + Duration::try_seconds(expires_in) - .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; - unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) }; - } - Ok(()) -}