mirror of
https://github.com/sigoden/aichat
synced 2024-11-10 07:10:36 +00:00
refactor: add/use openai_compatible_client macro
This commit is contained in:
parent
8d76fc77fb
commit
da3c541b68
@ -1,8 +1,5 @@
|
|||||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
use super::openai::openai_build_body;
|
||||||
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
|
use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
|
||||||
|
|
||||||
use crate::config::SharedConfig;
|
|
||||||
use crate::repl::ReplyStreamHandler;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@ -26,27 +23,7 @@ pub struct AzureOpenAIModel {
|
|||||||
max_tokens: Option<usize>,
|
max_tokens: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
openai_compatible_client!(AzureOpenAIClient);
|
||||||
impl Client for AzureOpenAIClient {
|
|
||||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
|
||||||
(&self.global_config, &self.config.extra)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message(builder).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_streaming_inner(
|
|
||||||
&self,
|
|
||||||
client: &ReqwestClient,
|
|
||||||
handler: &mut ReplyStreamHandler,
|
|
||||||
data: SendData,
|
|
||||||
) -> Result<()> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message_streaming(builder, handler).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AzureOpenAIClient {
|
impl AzureOpenAIClient {
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
|
@ -15,12 +15,18 @@ use tokio::time::sleep;
|
|||||||
use super::{openai::OpenAIConfig, ClientConfig};
|
use super::{openai::OpenAIConfig, ClientConfig};
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! register_role {
|
macro_rules! register_client {
|
||||||
(
|
(
|
||||||
$(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
|
$(($module:ident, $name:literal, $config_key:ident, $config:ident, $client:ident),)+
|
||||||
) => {
|
) => {
|
||||||
|
$(
|
||||||
|
mod $module;
|
||||||
|
)+
|
||||||
|
$(
|
||||||
|
use self::$module::$config;
|
||||||
|
)+
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum ClientConfig {
|
pub enum ClientConfig {
|
||||||
$(
|
$(
|
||||||
@ -35,15 +41,15 @@ macro_rules! register_role {
|
|||||||
$(
|
$(
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct $client {
|
pub struct $client {
|
||||||
global_config: SharedConfig,
|
global_config: $crate::config::SharedConfig,
|
||||||
config: $config,
|
config: $config,
|
||||||
model_info: ModelInfo,
|
model_info: $crate::config::ModelInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl $client {
|
impl $client {
|
||||||
pub const NAME: &str = $name;
|
pub const NAME: &str = $name;
|
||||||
|
|
||||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
pub fn init(global_config: $crate::config::SharedConfig) -> Option<Box<dyn Client>> {
|
||||||
let model_info = global_config.read().model_info.clone();
|
let model_info = global_config.read().model_info.clone();
|
||||||
let config = {
|
let config = {
|
||||||
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
|
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
|
||||||
@ -66,12 +72,12 @@ macro_rules! register_role {
|
|||||||
|
|
||||||
)+
|
)+
|
||||||
|
|
||||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
pub fn init_client(config: $crate::config::SharedConfig) -> anyhow::Result<Box<dyn Client>> {
|
||||||
None
|
None
|
||||||
$(.or_else(|| $client::init(config.clone())))+
|
$(.or_else(|| $client::init(config.clone())))+
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
let model_info = config.read().model_info.clone();
|
let model_info = config.read().model_info.clone();
|
||||||
anyhow!(
|
anyhow::anyhow!(
|
||||||
"Unknown client {} at config.clients[{}]",
|
"Unknown client {} at config.clients[{}]",
|
||||||
&model_info.client,
|
&model_info.client,
|
||||||
&model_info.index
|
&model_info.index
|
||||||
@ -83,16 +89,16 @@ macro_rules! register_role {
|
|||||||
vec![$($client::NAME,)+]
|
vec![$($client::NAME,)+]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_client_config(client: &str) -> Result<Value> {
|
pub fn create_client_config(client: &str) -> anyhow::Result<serde_json::Value> {
|
||||||
$(
|
$(
|
||||||
if client == $client::NAME {
|
if client == $client::NAME {
|
||||||
return create_config(&$client::PROMPTS, $client::NAME)
|
return create_config(&$client::PROMPTS, $client::NAME)
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
bail!("Unknown client {}", client)
|
anyhow::bail!("Unknown client {}", client)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn all_models(config: &Config) -> Vec<ModelInfo> {
|
pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::config::ModelInfo> {
|
||||||
config
|
config
|
||||||
.clients
|
.clients
|
||||||
.iter()
|
.iter()
|
||||||
@ -107,16 +113,53 @@ macro_rules! register_role {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! openai_compatible_client {
|
||||||
|
($client:ident) => {
|
||||||
|
#[async_trait]
|
||||||
|
impl $crate::client::Client for $crate::client::$client {
|
||||||
|
fn config(
|
||||||
|
&self,
|
||||||
|
) -> (
|
||||||
|
&$crate::config::SharedConfig,
|
||||||
|
&Option<$crate::client::ExtraConfig>,
|
||||||
|
) {
|
||||||
|
(&self.global_config, &self.config.extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_message_inner(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
data: $crate::client::SendData,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let builder = self.request_builder(client, data)?;
|
||||||
|
$crate::client::openai::openai_send_message(builder).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_message_streaming_inner(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
handler: &mut $crate::repl::ReplyStreamHandler,
|
||||||
|
data: $crate::client::SendData,
|
||||||
|
) -> Result<()> {
|
||||||
|
let builder = self.request_builder(client, data)?;
|
||||||
|
$crate::client::openai::openai_send_message_streaming(builder, handler).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
macro_rules! config_get_fn {
|
macro_rules! config_get_fn {
|
||||||
($field_name:ident, $fn_name:ident) => {
|
($field_name:ident, $fn_name:ident) => {
|
||||||
fn $fn_name(&self) -> Result<String> {
|
fn $fn_name(&self) -> anyhow::Result<String> {
|
||||||
let api_key = self.config.$field_name.clone();
|
let api_key = self.config.$field_name.clone();
|
||||||
api_key
|
api_key
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
let env_prefix = Self::name(&self.config);
|
let env_prefix = Self::name(&self.config);
|
||||||
let env_name =
|
let env_name =
|
||||||
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
|
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
|
||||||
env::var(&env_name).ok()
|
std::env::var(&env_name).ok()
|
||||||
})
|
})
|
||||||
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
|
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
use super::openai::openai_build_body;
|
||||||
use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
|
use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
|
||||||
|
|
||||||
use crate::config::SharedConfig;
|
|
||||||
use crate::repl::ReplyStreamHandler;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct LocalAIConfig {
|
pub struct LocalAIConfig {
|
||||||
@ -26,27 +22,7 @@ pub struct LocalAIModel {
|
|||||||
max_tokens: Option<usize>,
|
max_tokens: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
openai_compatible_client!(LocalAIClient);
|
||||||
impl Client for LocalAIClient {
|
|
||||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
|
||||||
(&self.global_config, &self.config.extra)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message(builder).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_streaming_inner(
|
|
||||||
&self,
|
|
||||||
client: &ReqwestClient,
|
|
||||||
handler: &mut ReplyStreamHandler,
|
|
||||||
data: SendData,
|
|
||||||
) -> Result<()> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message_streaming(builder, handler).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LocalAIClient {
|
impl LocalAIClient {
|
||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
|
@ -1,29 +1,15 @@
|
|||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
pub mod azure_openai;
|
|
||||||
pub mod localai;
|
|
||||||
pub mod openai;
|
|
||||||
|
|
||||||
pub use common::*;
|
pub use common::*;
|
||||||
|
|
||||||
use self::azure_openai::AzureOpenAIConfig;
|
use crate::{config::ModelInfo, repl::ReplyStreamHandler, utils::PromptKind};
|
||||||
use self::localai::LocalAIConfig;
|
|
||||||
use self::openai::OpenAIConfig;
|
|
||||||
|
|
||||||
use crate::{
|
register_client!(
|
||||||
config::{Config, ModelInfo, SharedConfig},
|
(openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
|
||||||
utils::PromptKind,
|
(localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Result};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
register_role!(
|
|
||||||
("openai", OpenAI, OpenAIConfig, OpenAIClient),
|
|
||||||
("localai", LocalAI, LocalAIConfig, LocalAIClient),
|
|
||||||
(
|
(
|
||||||
|
azure_openai,
|
||||||
"azure-openai",
|
"azure-openai",
|
||||||
AzureOpenAI,
|
AzureOpenAI,
|
||||||
AzureOpenAIConfig,
|
AzureOpenAIConfig,
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData};
|
use super::{ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData, ReplyStreamHandler};
|
||||||
|
|
||||||
use crate::config::SharedConfig;
|
|
||||||
use crate::repl::ReplyStreamHandler;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@ -29,27 +26,7 @@ pub struct OpenAIConfig {
|
|||||||
pub extra: Option<ExtraConfig>,
|
pub extra: Option<ExtraConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
openai_compatible_client!(OpenAIClient);
|
||||||
impl Client for OpenAIClient {
|
|
||||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
|
||||||
(&self.global_config, &self.config.extra)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message(builder).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message_streaming_inner(
|
|
||||||
&self,
|
|
||||||
client: &ReqwestClient,
|
|
||||||
handler: &mut ReplyStreamHandler,
|
|
||||||
data: SendData,
|
|
||||||
) -> Result<()> {
|
|
||||||
let builder = self.request_builder(client, data)?;
|
|
||||||
openai_send_message_streaming(builder, handler).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIClient {
|
impl OpenAIClient {
|
||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
|
Loading…
Reference in New Issue
Block a user