refactor: add/use openai_compatible_client macro

This commit is contained in:
sigoden 2023-11-01 15:01:50 +08:00
parent 8d76fc77fb
commit da3c541b68
5 changed files with 69 additions and 110 deletions

View File

@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use super::openai::openai_build_body;
use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
@ -26,27 +23,7 @@ pub struct AzureOpenAIModel {
max_tokens: Option<usize>,
}
#[async_trait]
impl Client for AzureOpenAIClient {
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
openai_compatible_client!(AzureOpenAIClient);
impl AzureOpenAIClient {
config_get_fn!(api_base, get_api_base);

View File

@ -15,12 +15,18 @@ use tokio::time::sleep;
use super::{openai::OpenAIConfig, ClientConfig};
#[macro_export]
macro_rules! register_role {
macro_rules! register_client {
(
$(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
$(($module:ident, $name:literal, $config_key:ident, $config:ident, $client:ident),)+
) => {
$(
mod $module;
)+
$(
use self::$module::$config;
)+
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
$(
@ -35,15 +41,15 @@ macro_rules! register_role {
$(
#[derive(Debug)]
pub struct $client {
global_config: SharedConfig,
global_config: $crate::config::SharedConfig,
config: $config,
model_info: ModelInfo,
model_info: $crate::config::ModelInfo,
}
impl $client {
pub const NAME: &str = $name;
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
pub fn init(global_config: $crate::config::SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
@ -66,12 +72,12 @@ macro_rules! register_role {
)+
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
pub fn init_client(config: $crate::config::SharedConfig) -> anyhow::Result<Box<dyn Client>> {
None
$(.or_else(|| $client::init(config.clone())))+
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
anyhow!(
anyhow::anyhow!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
@ -83,16 +89,16 @@ macro_rules! register_role {
vec![$($client::NAME,)+]
}
pub fn create_client_config(client: &str) -> Result<Value> {
pub fn create_client_config(client: &str) -> anyhow::Result<serde_json::Value> {
$(
if client == $client::NAME {
return create_config(&$client::PROMPTS, $client::NAME)
}
)+
bail!("Unknown client {}", client)
anyhow::bail!("Unknown client {}", client)
}
pub fn all_models(config: &Config) -> Vec<ModelInfo> {
pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::config::ModelInfo> {
config
.clients
.iter()
@ -107,16 +113,53 @@ macro_rules! register_role {
};
}
#[macro_export]
macro_rules! openai_compatible_client {
($client:ident) => {
#[async_trait]
impl $crate::client::Client for $crate::client::$client {
fn config(
&self,
) -> (
&$crate::config::SharedConfig,
&Option<$crate::client::ExtraConfig>,
) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(
&self,
client: &reqwest::Client,
data: $crate::client::SendData,
) -> anyhow::Result<String> {
let builder = self.request_builder(client, data)?;
$crate::client::openai::openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::repl::ReplyStreamHandler,
data: $crate::client::SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
$crate::client::openai::openai_send_message_streaming(builder, handler).await
}
}
};
}
#[macro_export]
macro_rules! config_get_fn {
($field_name:ident, $fn_name:ident) => {
fn $fn_name(&self) -> Result<String> {
fn $fn_name(&self) -> anyhow::Result<String> {
let api_key = self.config.$field_name.clone();
api_key
.or_else(|| {
let env_prefix = Self::name(&self.config);
let env_name =
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
env::var(&env_name).ok()
std::env::var(&env_name).ok()
})
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
}

View File

@ -1,14 +1,10 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use super::openai::openai_build_body;
use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use std::env;
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIConfig {
@ -26,27 +22,7 @@ pub struct LocalAIModel {
max_tokens: Option<usize>,
}
#[async_trait]
impl Client for LocalAIClient {
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
openai_compatible_client!(LocalAIClient);
impl LocalAIClient {
config_get_fn!(api_key, get_api_key);

View File

@ -1,29 +1,15 @@
#[macro_use]
mod common;
pub mod azure_openai;
pub mod localai;
pub mod openai;
pub use common::*;
use self::azure_openai::AzureOpenAIConfig;
use self::localai::LocalAIConfig;
use self::openai::OpenAIConfig;
use crate::{config::ModelInfo, repl::ReplyStreamHandler, utils::PromptKind};
use crate::{
config::{Config, ModelInfo, SharedConfig},
utils::PromptKind,
};
use anyhow::{anyhow, bail, Result};
use serde::Deserialize;
use serde_json::Value;
register_role!(
("openai", OpenAI, OpenAIConfig, OpenAIClient),
("localai", LocalAI, LocalAIConfig, LocalAIClient),
register_client!(
(openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
(localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
(
azure_openai,
"azure-openai",
AzureOpenAI,
AzureOpenAIConfig,

View File

@ -1,7 +1,4 @@
use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use super::{ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData, ReplyStreamHandler};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@ -29,27 +26,7 @@ pub struct OpenAIConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for OpenAIClient {
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
openai_compatible_client!(OpenAIClient);
impl OpenAIClient {
config_get_fn!(api_key, get_api_key);