diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index cabcebd..d0fa31f 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,8 +1,5 @@ -use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming}; -use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData}; - -use crate::config::SharedConfig; -use crate::repl::ReplyStreamHandler; +use super::openai::openai_build_body; +use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -26,27 +23,7 @@ pub struct AzureOpenAIModel { max_tokens: Option, } -#[async_trait] -impl Client for AzureOpenAIClient { - fn config(&self) -> (&SharedConfig, &Option) { - (&self.global_config, &self.config.extra) - } - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - openai_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyStreamHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - openai_send_message_streaming(builder, handler).await - } -} +openai_compatible_client!(AzureOpenAIClient); impl AzureOpenAIClient { config_get_fn!(api_base, get_api_base); diff --git a/src/client/common.rs b/src/client/common.rs index 6e5f365..c77f89b 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -15,12 +15,18 @@ use tokio::time::sleep; use super::{openai::OpenAIConfig, ClientConfig}; #[macro_export] -macro_rules! register_role { +macro_rules! register_client { ( - $(($name:literal, $config_key:ident, $config:ident, $client:ident),)+ + $(($module:ident, $name:literal, $config_key:ident, $config:ident, $client:ident),)+ ) => { + $( + mod $module; + )+ + $( + use self::$module::$config; + )+ - #[derive(Debug, Clone, Deserialize)] + #[derive(Debug, Clone, serde::Deserialize)] #[serde(tag = "type")] pub enum ClientConfig { $( @@ -35,15 +41,15 @@ macro_rules! register_role { $( #[derive(Debug)] pub struct $client { - global_config: SharedConfig, + global_config: $crate::config::SharedConfig, config: $config, - model_info: ModelInfo, + model_info: $crate::config::ModelInfo, } impl $client { pub const NAME: &str = $name; - pub fn init(global_config: SharedConfig) -> Option> { + pub fn init(global_config: $crate::config::SharedConfig) -> Option> { let model_info = global_config.read().model_info.clone(); let config = { if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] { @@ -66,12 +72,12 @@ macro_rules! register_role { )+ - pub fn init_client(config: SharedConfig) -> Result> { + pub fn init_client(config: $crate::config::SharedConfig) -> anyhow::Result> { None $(.or_else(|| $client::init(config.clone())))+ .ok_or_else(|| { let model_info = config.read().model_info.clone(); - anyhow!( + anyhow::anyhow!( "Unknown client {} at config.clients[{}]", &model_info.client, &model_info.index @@ -83,16 +89,16 @@ macro_rules! register_role { vec![$($client::NAME,)+] } - pub fn create_client_config(client: &str) -> Result { + pub fn create_client_config(client: &str) -> anyhow::Result { $( if client == $client::NAME { return create_config(&$client::PROMPTS, $client::NAME) } )+ - bail!("Unknown client {}", client) + anyhow::bail!("Unknown client {}", client) } - pub fn all_models(config: &Config) -> Vec { + pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::config::ModelInfo> { config .clients .iter() @@ -107,16 +113,53 @@ macro_rules! register_role { }; } +#[macro_export] +macro_rules! openai_compatible_client { + ($client:ident) => { + #[async_trait] + impl $crate::client::Client for $crate::client::$client { + fn config( + &self, + ) -> ( + &$crate::config::SharedConfig, + &Option<$crate::client::ExtraConfig>, + ) { + (&self.global_config, &self.config.extra) + } + + async fn send_message_inner( + &self, + client: &reqwest::Client, + data: $crate::client::SendData, + ) -> anyhow::Result { + let builder = self.request_builder(client, data)?; + $crate::client::openai::openai_send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + client: &reqwest::Client, + handler: &mut $crate::repl::ReplyStreamHandler, + data: $crate::client::SendData, + ) -> Result<()> { + let builder = self.request_builder(client, data)?; + $crate::client::openai::openai_send_message_streaming(builder, handler).await + } + } + }; +} + +#[macro_export] macro_rules! config_get_fn { ($field_name:ident, $fn_name:ident) => { - fn $fn_name(&self) -> Result { + fn $fn_name(&self) -> anyhow::Result { let api_key = self.config.$field_name.clone(); api_key .or_else(|| { let env_prefix = Self::name(&self.config); let env_name = format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); - env::var(&env_name).ok() + std::env::var(&env_name).ok() }) .ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name))) } diff --git a/src/client/localai.rs b/src/client/localai.rs index 131d92b..291fc96 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,14 +1,10 @@ -use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming}; -use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData}; - -use crate::config::SharedConfig; -use crate::repl::ReplyStreamHandler; +use super::openai::openai_build_body; +use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData}; use anyhow::Result; use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; -use std::env; #[derive(Debug, Clone, Deserialize)] pub struct LocalAIConfig { @@ -26,27 +22,7 @@ pub struct LocalAIModel { max_tokens: Option, } -#[async_trait] -impl Client for LocalAIClient { - fn config(&self) -> (&SharedConfig, &Option) { - (&self.global_config, &self.config.extra) - } - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - openai_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyStreamHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - openai_send_message_streaming(builder, handler).await - } -} +openai_compatible_client!(LocalAIClient); impl LocalAIClient { config_get_fn!(api_key, get_api_key); diff --git a/src/client/mod.rs b/src/client/mod.rs index 5562eb6..dba049d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,29 +1,15 @@ #[macro_use] mod common; -pub mod azure_openai; -pub mod localai; -pub mod openai; - pub use common::*; -use self::azure_openai::AzureOpenAIConfig; -use self::localai::LocalAIConfig; -use self::openai::OpenAIConfig; - -use crate::{ - config::{Config, ModelInfo, SharedConfig}, - utils::PromptKind, -}; - -use anyhow::{anyhow, bail, Result}; -use serde::Deserialize; -use serde_json::Value; +use crate::{config::ModelInfo, repl::ReplyStreamHandler, utils::PromptKind}; -register_role!( - ("openai", OpenAI, OpenAIConfig, OpenAIClient), - ("localai", LocalAI, LocalAIConfig, LocalAIClient), +register_client!( + (openai, "openai", OpenAI, OpenAIConfig, OpenAIClient), + (localai, "localai", LocalAI, LocalAIConfig, LocalAIClient), ( + azure_openai, "azure-openai", AzureOpenAI, AzureOpenAIConfig, diff --git a/src/client/openai.rs b/src/client/openai.rs index 4ae3c5f..b554f35 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,7 +1,4 @@ -use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData}; - -use crate::config::SharedConfig; -use crate::repl::ReplyStreamHandler; +use super::{ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData, ReplyStreamHandler}; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; @@ -29,27 +26,7 @@ pub struct OpenAIConfig { pub extra: Option, } -#[async_trait] -impl Client for OpenAIClient { - fn config(&self) -> (&SharedConfig, &Option) { - (&self.global_config, &self.config.extra) - } - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - openai_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyStreamHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - openai_send_message_streaming(builder, handler).await - } -} +openai_compatible_client!(OpenAIClient); impl OpenAIClient { config_get_fn!(api_key, get_api_key);