From d0208937633cbf255471269e079c10e862cc4e39 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 27 Jul 2024 09:41:28 +0800 Subject: [PATCH] refactor: split clients/macros.rs from clients/common.rs (#750) --- src/client/common.rs | 293 ------------------------------------------- src/client/macros.rs | 292 ++++++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 5 +- 3 files changed, 295 insertions(+), 295 deletions(-) create mode 100644 src/client/macros.rs diff --git a/src/client/common.rs b/src/client/common.rs index 1d802c5..71bf811 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -25,299 +25,6 @@ lazy_static! { static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(? { - $( - mod $module; - )+ - $( - use self::$module::$config; - )+ - - #[derive(Debug, Clone, serde::Deserialize)] - #[serde(tag = "type")] - pub enum ClientConfig { - $( - #[serde(rename = $name)] - $config($config), - )+ - #[serde(other)] - Unknown, - } - - $( - #[derive(Debug)] - pub struct $client { - global_config: $crate::config::GlobalConfig, - config: $config, - model: $crate::client::Model, - } - - impl $client { - pub const NAME: &'static str = $name; - - pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option> { - let config = global_config.read().clients.iter().find_map(|client_config| { - if let ClientConfig::$config(c) = client_config { - if Self::name(c) == model.client_name() { - return Some(c.clone()) - } - } - None - })?; - - Some(Box::new(Self { - global_config: global_config.clone(), - config, - model: model.clone(), - })) - } - - pub fn list_models(local_config: &$config) -> Vec { - let client_name = Self::name(local_config); - if local_config.models.is_empty() { - if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| { - v.platform == $name || - ($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) || - ($name == RagDedicatedClient::NAME && local_config.name.as_deref() == Some(&v.platform)) - }) { - return Model::from_config(client_name, &models.models); - } - vec![] - } else { - Model::from_config(client_name, &local_config.models) - } - } - - pub fn name(local_config: &$config) -> &str { - local_config.name.as_deref().unwrap_or(Self::NAME) - } - } - - )+ - - pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result> { - let model = model.unwrap_or_else(|| config.read().model.clone()); - None - $(.or_else(|| $client::init(config, &model)))+ - .ok_or_else(|| { - anyhow::anyhow!("Invalid model '{}'", model.id()) - }) - } - - pub fn list_client_types() -> Vec<&'static str> { - let mut client_types: Vec<_> = vec![$($client::NAME,)+]; - client_types.extend($crate::client::OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name)); - client_types - } - - pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> { - $( - if client == $client::NAME { - return create_config(&$client::PROMPTS, $client::NAME) - } - )+ - if let Some(ret) = create_openai_compatible_client_config(client)? { - return Ok(ret); - } - anyhow::bail!("Unknown client '{}'", client) - } - - static mut ALL_CLIENT_MODELS: Option> = None; - - pub fn list_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - if unsafe { ALL_CLIENT_MODELS.is_none() } { - let models: Vec<_> = config - .clients - .iter() - .flat_map(|v| match v { - $(ClientConfig::$config(c) => $client::list_models(c),)+ - ClientConfig::Unknown => vec![], - }) - .collect(); - unsafe { ALL_CLIENT_MODELS = Some(models) }; - } - unsafe { ALL_CLIENT_MODELS.as_ref().unwrap().iter().collect() } - } - - pub fn list_chat_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "chat").collect() - } - - pub fn list_embedding_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "embedding").collect() - } - - pub fn list_reranker_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "reranker").collect() - } - }; -} - -#[macro_export] -macro_rules! client_common_fns { - () => { - fn global_config(&self) -> &$crate::config::GlobalConfig { - &self.global_config - } - - fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> { - self.config.extra.as_ref() - } - - fn patches_config(&self) -> Option<&$crate::client::ModelPatches> { - self.config.patches.as_ref() - } - - fn name(&self) -> &str { - Self::name(&self.config) - } - - fn model(&self) -> &Model { - &self.model - } - - fn model_mut(&mut self) -> &mut Model { - &mut self.model - } - }; -} - -#[macro_export] -macro_rules! impl_client_trait { - ($client:ident, $chat_completions:path, $chat_completions_streaming:path) => { - #[async_trait::async_trait] - impl $crate::client::Client for $crate::client::$client { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &reqwest::Client, - data: $crate::client::ChatCompletionsData, - ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &reqwest::Client, - handler: &mut $crate::client::SseHandler, - data: $crate::client::ChatCompletionsData, - ) -> Result<()> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions_streaming(builder, handler).await - } - } - }; - ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path) => { - #[async_trait::async_trait] - impl $crate::client::Client for $crate::client::$client { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &reqwest::Client, - data: $crate::client::ChatCompletionsData, - ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &reqwest::Client, - handler: &mut $crate::client::SseHandler, - data: $crate::client::ChatCompletionsData, - ) -> Result<()> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions_streaming(builder, handler).await - } - - async fn embeddings_inner( - &self, - client: &reqwest::Client, - data: $crate::client::EmbeddingsData, - ) -> Result<$crate::client::EmbeddingsOutput> { - let builder = self.embeddings_builder(client, data)?; - $embeddings(builder).await - } - } - }; - ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path, $rerank:path) => { - #[async_trait::async_trait] - impl $crate::client::Client for $crate::client::$client { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &reqwest::Client, - data: $crate::client::ChatCompletionsData, - ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &reqwest::Client, - handler: &mut $crate::client::SseHandler, - data: $crate::client::ChatCompletionsData, - ) -> Result<()> { - let builder = self.chat_completions_builder(client, data)?; - $chat_completions_streaming(builder, handler).await - } - - async fn embeddings_inner( - &self, - client: &reqwest::Client, - data: $crate::client::EmbeddingsData, - ) -> Result<$crate::client::EmbeddingsOutput> { - let builder = self.embeddings_builder(client, data)?; - $embeddings(builder).await - } - - async fn rerank_inner( - &self, - client: &reqwest::Client, - data: $crate::client::RerankData, - ) -> Result<$crate::client::RerankOutput> { - let builder = self.rerank_builder(client, data)?; - $rerank(builder).await - } - } - }; -} - -#[macro_export] -macro_rules! config_get_fn { - ($field_name:ident, $fn_name:ident) => { - fn $fn_name(&self) -> anyhow::Result { - let api_key = self.config.$field_name.clone(); - api_key - .or_else(|| { - let env_prefix = Self::name(&self.config); - let env_name = - format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); - std::env::var(&env_name).ok() - }) - .ok_or_else(|| { - anyhow::anyhow!("Miss '{}' in client configuration", stringify!($field_name)) - }) - } - }; -} - -#[macro_export] -macro_rules! unsupported_model { - ($name:expr) => { - anyhow::bail!("Unsupported model '{}'", $name) - }; -} - #[async_trait] pub trait Client: Sync + Send { fn global_config(&self) -> &GlobalConfig; diff --git a/src/client/macros.rs b/src/client/macros.rs new file mode 100644 index 0000000..344bc92 --- /dev/null +++ b/src/client/macros.rs @@ -0,0 +1,292 @@ +#[macro_export] +macro_rules! register_client { + ( + $(($module:ident, $name:literal, $config:ident, $client:ident),)+ + ) => { + $( + mod $module; + )+ + $( + use self::$module::$config; + )+ + + #[derive(Debug, Clone, serde::Deserialize)] + #[serde(tag = "type")] + pub enum ClientConfig { + $( + #[serde(rename = $name)] + $config($config), + )+ + #[serde(other)] + Unknown, + } + + $( + #[derive(Debug)] + pub struct $client { + global_config: $crate::config::GlobalConfig, + config: $config, + model: $crate::client::Model, + } + + impl $client { + pub const NAME: &'static str = $name; + + pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option> { + let config = global_config.read().clients.iter().find_map(|client_config| { + if let ClientConfig::$config(c) = client_config { + if Self::name(c) == model.client_name() { + return Some(c.clone()) + } + } + None + })?; + + Some(Box::new(Self { + global_config: global_config.clone(), + config, + model: model.clone(), + })) + } + + pub fn list_models(local_config: &$config) -> Vec { + let client_name = Self::name(local_config); + if local_config.models.is_empty() { + if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| { + v.platform == $name || + ($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) || + ($name == RagDedicatedClient::NAME && local_config.name.as_deref() == Some(&v.platform)) + }) { + return Model::from_config(client_name, &models.models); + } + vec![] + } else { + Model::from_config(client_name, &local_config.models) + } + } + + pub fn name(local_config: &$config) -> &str { + local_config.name.as_deref().unwrap_or(Self::NAME) + } + } + + )+ + + pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result> { + let model = model.unwrap_or_else(|| config.read().model.clone()); + None + $(.or_else(|| $client::init(config, &model)))+ + .ok_or_else(|| { + anyhow::anyhow!("Invalid model '{}'", model.id()) + }) + } + + pub fn list_client_types() -> Vec<&'static str> { + let mut client_types: Vec<_> = vec![$($client::NAME,)+]; + client_types.extend($crate::client::OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name)); + client_types + } + + pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> { + $( + if client == $client::NAME { + return create_config(&$client::PROMPTS, $client::NAME) + } + )+ + if let Some(ret) = create_openai_compatible_client_config(client)? { + return Ok(ret); + } + anyhow::bail!("Unknown client '{}'", client) + } + + static mut ALL_CLIENT_MODELS: Option> = None; + + pub fn list_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + if unsafe { ALL_CLIENT_MODELS.is_none() } { + let models: Vec<_> = config + .clients + .iter() + .flat_map(|v| match v { + $(ClientConfig::$config(c) => $client::list_models(c),)+ + ClientConfig::Unknown => vec![], + }) + .collect(); + unsafe { ALL_CLIENT_MODELS = Some(models) }; + } + unsafe { ALL_CLIENT_MODELS.as_ref().unwrap().iter().collect() } + } + + pub fn list_chat_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + list_models(config).into_iter().filter(|v| v.model_type() == "chat").collect() + } + + pub fn list_embedding_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + list_models(config).into_iter().filter(|v| v.model_type() == "embedding").collect() + } + + pub fn list_reranker_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + list_models(config).into_iter().filter(|v| v.model_type() == "reranker").collect() + } + }; +} + +#[macro_export] +macro_rules! client_common_fns { + () => { + fn global_config(&self) -> &$crate::config::GlobalConfig { + &self.global_config + } + + fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> { + self.config.extra.as_ref() + } + + fn patches_config(&self) -> Option<&$crate::client::ModelPatches> { + self.config.patches.as_ref() + } + + fn name(&self) -> &str { + Self::name(&self.config) + } + + fn model(&self) -> &Model { + &self.model + } + + fn model_mut(&mut self) -> &mut Model { + &mut self.model + } + }; +} + +#[macro_export] +macro_rules! impl_client_trait { + ($client:ident, $chat_completions:path, $chat_completions_streaming:path) => { + #[async_trait::async_trait] + impl $crate::client::Client for $crate::client::$client { + client_common_fns!(); + + async fn chat_completions_inner( + &self, + client: &reqwest::Client, + data: $crate::client::ChatCompletionsData, + ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions(builder).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &reqwest::Client, + handler: &mut $crate::client::SseHandler, + data: $crate::client::ChatCompletionsData, + ) -> Result<()> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions_streaming(builder, handler).await + } + } + }; + ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path) => { + #[async_trait::async_trait] + impl $crate::client::Client for $crate::client::$client { + client_common_fns!(); + + async fn chat_completions_inner( + &self, + client: &reqwest::Client, + data: $crate::client::ChatCompletionsData, + ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions(builder).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &reqwest::Client, + handler: &mut $crate::client::SseHandler, + data: $crate::client::ChatCompletionsData, + ) -> Result<()> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions_streaming(builder, handler).await + } + + async fn embeddings_inner( + &self, + client: &reqwest::Client, + data: $crate::client::EmbeddingsData, + ) -> Result<$crate::client::EmbeddingsOutput> { + let builder = self.embeddings_builder(client, data)?; + $embeddings(builder).await + } + } + }; + ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path, $rerank:path) => { + #[async_trait::async_trait] + impl $crate::client::Client for $crate::client::$client { + client_common_fns!(); + + async fn chat_completions_inner( + &self, + client: &reqwest::Client, + data: $crate::client::ChatCompletionsData, + ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions(builder).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &reqwest::Client, + handler: &mut $crate::client::SseHandler, + data: $crate::client::ChatCompletionsData, + ) -> Result<()> { + let builder = self.chat_completions_builder(client, data)?; + $chat_completions_streaming(builder, handler).await + } + + async fn embeddings_inner( + &self, + client: &reqwest::Client, + data: $crate::client::EmbeddingsData, + ) -> Result<$crate::client::EmbeddingsOutput> { + let builder = self.embeddings_builder(client, data)?; + $embeddings(builder).await + } + + async fn rerank_inner( + &self, + client: &reqwest::Client, + data: $crate::client::RerankData, + ) -> Result<$crate::client::RerankOutput> { + let builder = self.rerank_builder(client, data)?; + $rerank(builder).await + } + } + }; +} + +#[macro_export] +macro_rules! config_get_fn { + ($field_name:ident, $fn_name:ident) => { + fn $fn_name(&self) -> anyhow::Result { + let api_key = self.config.$field_name.clone(); + api_key + .or_else(|| { + let env_prefix = Self::name(&self.config); + let env_name = + format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); + std::env::var(&env_name).ok() + }) + .ok_or_else(|| { + anyhow::anyhow!("Miss '{}' in client configuration", stringify!($field_name)) + }) + } + }; +} + +#[macro_export] +macro_rules! unsupported_model { + ($name:expr) => { + anyhow::bail!("Unsupported model '{}'", $name) + }; +} diff --git a/src/client/mod.rs b/src/client/mod.rs index dd1d0b1..b2bd3de 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,8 @@ -#[macro_use] -mod common; mod access_token; +mod common; mod message; +#[macro_use] +mod macros; mod model; mod prompt_format; mod stream;