mirror of
https://github.com/sigoden/aichat
synced 2024-11-10 07:10:36 +00:00
refactor: add/use openai_compatible_client macro
This commit is contained in:
parent
8d76fc77fb
commit
da3c541b68
@ -1,8 +1,5 @@
|
||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
||||
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
use super::openai::openai_build_body;
|
||||
use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
@ -26,27 +23,7 @@ pub struct AzureOpenAIModel {
|
||||
max_tokens: Option<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for AzureOpenAIClient {
|
||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
||||
(&self.global_config, &self.config.extra)
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
openai_compatible_client!(AzureOpenAIClient);
|
||||
|
||||
impl AzureOpenAIClient {
|
||||
config_get_fn!(api_base, get_api_base);
|
||||
|
@ -15,12 +15,18 @@ use tokio::time::sleep;
|
||||
use super::{openai::OpenAIConfig, ClientConfig};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! register_role {
|
||||
macro_rules! register_client {
|
||||
(
|
||||
$(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
|
||||
$(($module:ident, $name:literal, $config_key:ident, $config:ident, $client:ident),)+
|
||||
) => {
|
||||
$(
|
||||
mod $module;
|
||||
)+
|
||||
$(
|
||||
use self::$module::$config;
|
||||
)+
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
$(
|
||||
@ -35,15 +41,15 @@ macro_rules! register_role {
|
||||
$(
|
||||
#[derive(Debug)]
|
||||
pub struct $client {
|
||||
global_config: SharedConfig,
|
||||
global_config: $crate::config::SharedConfig,
|
||||
config: $config,
|
||||
model_info: ModelInfo,
|
||||
model_info: $crate::config::ModelInfo,
|
||||
}
|
||||
|
||||
impl $client {
|
||||
pub const NAME: &str = $name;
|
||||
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
pub fn init(global_config: $crate::config::SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
let config = {
|
||||
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
|
||||
@ -66,12 +72,12 @@ macro_rules! register_role {
|
||||
|
||||
)+
|
||||
|
||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
pub fn init_client(config: $crate::config::SharedConfig) -> anyhow::Result<Box<dyn Client>> {
|
||||
None
|
||||
$(.or_else(|| $client::init(config.clone())))+
|
||||
.ok_or_else(|| {
|
||||
let model_info = config.read().model_info.clone();
|
||||
anyhow!(
|
||||
anyhow::anyhow!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
@ -83,16 +89,16 @@ macro_rules! register_role {
|
||||
vec![$($client::NAME,)+]
|
||||
}
|
||||
|
||||
pub fn create_client_config(client: &str) -> Result<Value> {
|
||||
pub fn create_client_config(client: &str) -> anyhow::Result<serde_json::Value> {
|
||||
$(
|
||||
if client == $client::NAME {
|
||||
return create_config(&$client::PROMPTS, $client::NAME)
|
||||
}
|
||||
)+
|
||||
bail!("Unknown client {}", client)
|
||||
anyhow::bail!("Unknown client {}", client)
|
||||
}
|
||||
|
||||
pub fn all_models(config: &Config) -> Vec<ModelInfo> {
|
||||
pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::config::ModelInfo> {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
@ -107,16 +113,53 @@ macro_rules! register_role {
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! openai_compatible_client {
|
||||
($client:ident) => {
|
||||
#[async_trait]
|
||||
impl $crate::client::Client for $crate::client::$client {
|
||||
fn config(
|
||||
&self,
|
||||
) -> (
|
||||
&$crate::config::SharedConfig,
|
||||
&Option<$crate::client::ExtraConfig>,
|
||||
) {
|
||||
(&self.global_config, &self.config.extra)
|
||||
}
|
||||
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
data: $crate::client::SendData,
|
||||
) -> anyhow::Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
$crate::client::openai::openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
handler: &mut $crate::repl::ReplyStreamHandler,
|
||||
data: $crate::client::SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
$crate::client::openai::openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! config_get_fn {
|
||||
($field_name:ident, $fn_name:ident) => {
|
||||
fn $fn_name(&self) -> Result<String> {
|
||||
fn $fn_name(&self) -> anyhow::Result<String> {
|
||||
let api_key = self.config.$field_name.clone();
|
||||
api_key
|
||||
.or_else(|| {
|
||||
let env_prefix = Self::name(&self.config);
|
||||
let env_name =
|
||||
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
|
||||
env::var(&env_name).ok()
|
||||
std::env::var(&env_name).ok()
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
|
||||
}
|
||||
|
@ -1,14 +1,10 @@
|
||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
||||
use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
use super::openai::openai_build_body;
|
||||
use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct LocalAIConfig {
|
||||
@ -26,27 +22,7 @@ pub struct LocalAIModel {
|
||||
max_tokens: Option<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for LocalAIClient {
|
||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
||||
(&self.global_config, &self.config.extra)
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
openai_compatible_client!(LocalAIClient);
|
||||
|
||||
impl LocalAIClient {
|
||||
config_get_fn!(api_key, get_api_key);
|
||||
|
@ -1,29 +1,15 @@
|
||||
#[macro_use]
|
||||
mod common;
|
||||
|
||||
pub mod azure_openai;
|
||||
pub mod localai;
|
||||
pub mod openai;
|
||||
|
||||
pub use common::*;
|
||||
|
||||
use self::azure_openai::AzureOpenAIConfig;
|
||||
use self::localai::LocalAIConfig;
|
||||
use self::openai::OpenAIConfig;
|
||||
use crate::{config::ModelInfo, repl::ReplyStreamHandler, utils::PromptKind};
|
||||
|
||||
use crate::{
|
||||
config::{Config, ModelInfo, SharedConfig},
|
||||
utils::PromptKind,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
|
||||
register_role!(
|
||||
("openai", OpenAI, OpenAIConfig, OpenAIClient),
|
||||
("localai", LocalAI, LocalAIConfig, LocalAIClient),
|
||||
register_client!(
|
||||
(openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
|
||||
(localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
|
||||
(
|
||||
azure_openai,
|
||||
"azure-openai",
|
||||
AzureOpenAI,
|
||||
AzureOpenAIConfig,
|
||||
|
@ -1,7 +1,4 @@
|
||||
use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
use super::{ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData, ReplyStreamHandler};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use async_trait::async_trait;
|
||||
@ -29,27 +26,7 @@ pub struct OpenAIConfig {
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for OpenAIClient {
|
||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
|
||||
(&self.global_config, &self.config.extra)
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
openai_compatible_client!(OpenAIClient);
|
||||
|
||||
impl OpenAIClient {
|
||||
config_get_fn!(api_key, get_api_key);
|
||||
|
Loading…
Reference in New Issue
Block a user