refactor: add/use openai_compatible_client macro

sigoden 2023-11-01 15:01:50 +08:00
parent 8d76fc77fb
commit da3c541b68
5 changed files with 69 additions and 110 deletions
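This replaces the hand-written `impl Client for ...` blocks in the OpenAI-compatible clients (openai, localai, azure-openai) with a single `openai_compatible_client!` call, and generalizes `register_role!` into `register_client!`, which now also declares each client module and imports its config type.

As a rough sketch of the intended workflow (the `myai` names below are invented for illustration and are not part of this commit), adding another OpenAI-compatible backend afterwards comes down to one new `register_client!` entry plus two macro calls in the new module:

// In the client module list: one extra tuple in the existing invocation.
register_client!(
    (openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
    (localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
    // ... azure-openai entry as in this commit ...
    (myai, "myai", MyAI, MyAIConfig, MyAIClient), // hypothetical new client
);

// In the new client's module: the whole Client trait impl is generated.
openai_compatible_client!(MyAIClient);

impl MyAIClient {
    config_get_fn!(api_key, get_api_key); // config value or MYAI_API_KEY env var
    // The config struct, PROMPTS, and request_builder() stay hand-written.
}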

View File

@@ -1,8 +1,5 @@
-use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
-use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
-use crate::config::SharedConfig;
-use crate::repl::ReplyStreamHandler;
+use super::openai::openai_build_body;
+use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
@@ -26,27 +23,7 @@ pub struct AzureOpenAIModel {
     max_tokens: Option<usize>,
 }
-#[async_trait]
-impl Client for AzureOpenAIClient {
-    fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
-    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message(builder).await
-    }
-    async fn send_message_streaming_inner(
-        &self,
-        client: &ReqwestClient,
-        handler: &mut ReplyStreamHandler,
-        data: SendData,
-    ) -> Result<()> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message_streaming(builder, handler).await
-    }
-}
+openai_compatible_client!(AzureOpenAIClient);
 impl AzureOpenAIClient {
     config_get_fn!(api_base, get_api_base);

View File

@@ -15,12 +15,18 @@ use tokio::time::sleep;
 use super::{openai::OpenAIConfig, ClientConfig};
 #[macro_export]
-macro_rules! register_role {
+macro_rules! register_client {
     (
-        $(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
+        $(($module:ident, $name:literal, $config_key:ident, $config:ident, $client:ident),)+
     ) => {
+        $(
+            mod $module;
+        )+
+        $(
+            use self::$module::$config;
+        )+
-        #[derive(Debug, Clone, Deserialize)]
+        #[derive(Debug, Clone, serde::Deserialize)]
         #[serde(tag = "type")]
         pub enum ClientConfig {
             $(
@@ -35,15 +41,15 @@ macro_rules! register_role {
         $(
             #[derive(Debug)]
             pub struct $client {
-                global_config: SharedConfig,
+                global_config: $crate::config::SharedConfig,
                 config: $config,
-                model_info: ModelInfo,
+                model_info: $crate::config::ModelInfo,
             }
             impl $client {
                 pub const NAME: &str = $name;
-                pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
+                pub fn init(global_config: $crate::config::SharedConfig) -> Option<Box<dyn Client>> {
                     let model_info = global_config.read().model_info.clone();
                     let config = {
                         if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
@@ -66,12 +72,12 @@ macro_rules! register_role {
         )+
-        pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
+        pub fn init_client(config: $crate::config::SharedConfig) -> anyhow::Result<Box<dyn Client>> {
             None
                 $(.or_else(|| $client::init(config.clone())))+
                 .ok_or_else(|| {
                     let model_info = config.read().model_info.clone();
-                    anyhow!(
+                    anyhow::anyhow!(
                         "Unknown client {} at config.clients[{}]",
                         &model_info.client,
                         &model_info.index
@@ -83,16 +89,16 @@ macro_rules! register_role {
             vec![$($client::NAME,)+]
         }
-        pub fn create_client_config(client: &str) -> Result<Value> {
+        pub fn create_client_config(client: &str) -> anyhow::Result<serde_json::Value> {
             $(
                 if client == $client::NAME {
                     return create_config(&$client::PROMPTS, $client::NAME)
                 }
             )+
-            bail!("Unknown client {}", client)
+            anyhow::bail!("Unknown client {}", client)
         }
-        pub fn all_models(config: &Config) -> Vec<ModelInfo> {
+        pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::config::ModelInfo> {
            config
                .clients
                .iter()
@@ -107,16 +113,53 @@ macro_rules! register_role {
     };
 }
+#[macro_export]
+macro_rules! openai_compatible_client {
+    ($client:ident) => {
+        #[async_trait]
+        impl $crate::client::Client for $crate::client::$client {
+            fn config(
+                &self,
+            ) -> (
+                &$crate::config::SharedConfig,
+                &Option<$crate::client::ExtraConfig>,
+            ) {
+                (&self.global_config, &self.config.extra)
+            }
+            async fn send_message_inner(
+                &self,
+                client: &reqwest::Client,
+                data: $crate::client::SendData,
+            ) -> anyhow::Result<String> {
+                let builder = self.request_builder(client, data)?;
+                $crate::client::openai::openai_send_message(builder).await
+            }
+            async fn send_message_streaming_inner(
+                &self,
+                client: &reqwest::Client,
+                handler: &mut $crate::repl::ReplyStreamHandler,
+                data: $crate::client::SendData,
+            ) -> Result<()> {
+                let builder = self.request_builder(client, data)?;
+                $crate::client::openai::openai_send_message_streaming(builder, handler).await
+            }
+        }
+    };
+}
 #[macro_export]
 macro_rules! config_get_fn {
     ($field_name:ident, $fn_name:ident) => {
-        fn $fn_name(&self) -> Result<String> {
+        fn $fn_name(&self) -> anyhow::Result<String> {
             let api_key = self.config.$field_name.clone();
             api_key
                 .or_else(|| {
                     let env_prefix = Self::name(&self.config);
                     let env_name =
                         format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
-                    env::var(&env_name).ok()
+                    std::env::var(&env_name).ok()
                 })
                 .ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
         }
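For reference, a call like `config_get_fn!(api_key, get_api_key)` (used by the client files below) expands to roughly the following; the environment variable name is built at runtime from the client name and the field name, e.g. LOCALAI_API_KEY (a sketch of the expansion, not compiler output):

fn get_api_key(&self) -> anyhow::Result<String> {
    // Prefer the value from the client's config section...
    let api_key = self.config.api_key.clone();
    api_key
        .or_else(|| {
            // ...then fall back to <CLIENT_NAME>_API_KEY from the environment.
            let env_prefix = Self::name(&self.config);
            let env_name = format!("{}_{}", env_prefix, "api_key").to_ascii_uppercase();
            std::env::var(&env_name).ok()
        })
        .ok_or_else(|| anyhow::anyhow!("Miss api_key"))
}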

View File

@@ -1,14 +1,10 @@
-use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
-use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
-use crate::config::SharedConfig;
-use crate::repl::ReplyStreamHandler;
+use super::openai::openai_build_body;
+use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
 use anyhow::Result;
 use async_trait::async_trait;
 use reqwest::{Client as ReqwestClient, RequestBuilder};
 use serde::Deserialize;
-use std::env;
 #[derive(Debug, Clone, Deserialize)]
 pub struct LocalAIConfig {
@@ -26,27 +22,7 @@ pub struct LocalAIModel {
     max_tokens: Option<usize>,
 }
-#[async_trait]
-impl Client for LocalAIClient {
-    fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
-    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message(builder).await
-    }
-    async fn send_message_streaming_inner(
-        &self,
-        client: &ReqwestClient,
-        handler: &mut ReplyStreamHandler,
-        data: SendData,
-    ) -> Result<()> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message_streaming(builder, handler).await
-    }
-}
+openai_compatible_client!(LocalAIClient);
 impl LocalAIClient {
     config_get_fn!(api_key, get_api_key);

View File

@@ -1,29 +1,15 @@
 #[macro_use]
 mod common;
-pub mod azure_openai;
-pub mod localai;
-pub mod openai;
 pub use common::*;
-use self::azure_openai::AzureOpenAIConfig;
-use self::localai::LocalAIConfig;
-use self::openai::OpenAIConfig;
-use crate::{
-    config::{Config, ModelInfo, SharedConfig},
-    utils::PromptKind,
-};
-use anyhow::{anyhow, bail, Result};
-use serde::Deserialize;
-use serde_json::Value;
-register_role!(
-    ("openai", OpenAI, OpenAIConfig, OpenAIClient),
-    ("localai", LocalAI, LocalAIConfig, LocalAIClient),
+use crate::{config::ModelInfo, repl::ReplyStreamHandler, utils::PromptKind};
+register_client!(
+    (openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
+    (localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
     (
+        azure_openai,
         "azure-openai",
         AzureOpenAI,
         AzureOpenAIConfig,
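To make the new `$module` parameter concrete, here is a rough sketch of what the `register_client!` invocation above generates via the macro shown earlier (heavily abbreviated; the per-client structs, init functions, and any serde attributes on the enum variants are omitted, and the exact variant shape is inferred rather than visible in this diff):

mod openai;
mod localai;
mod azure_openai;

use self::openai::OpenAIConfig;
use self::localai::LocalAIConfig;
use self::azure_openai::AzureOpenAIConfig;

#[derive(Debug, Clone, serde::Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
    OpenAI(OpenAIConfig),       // selected by `type: ...` in the clients config
    LocalAI(LocalAIConfig),
    AzureOpenAI(AzureOpenAIConfig),
}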

View File

@@ -1,7 +1,4 @@
-use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData};
-use crate::config::SharedConfig;
-use crate::repl::ReplyStreamHandler;
+use super::{ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData, ReplyStreamHandler};
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
@@ -29,27 +26,7 @@ pub struct OpenAIConfig {
     pub extra: Option<ExtraConfig>,
 }
-#[async_trait]
-impl Client for OpenAIClient {
-    fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
-    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message(builder).await
-    }
-    async fn send_message_streaming_inner(
-        &self,
-        client: &ReqwestClient,
-        handler: &mut ReplyStreamHandler,
-        data: SendData,
-    ) -> Result<()> {
-        let builder = self.request_builder(client, data)?;
-        openai_send_message_streaming(builder, handler).await
-    }
-}
+openai_compatible_client!(OpenAIClient);
 impl OpenAIClient {
     config_get_fn!(api_key, get_api_key);