chore: improve client-related code quality

This commit is contained in:
sigoden 2023-10-29 10:09:33 +08:00
parent 557bed1459
commit 985f8c0946
3 changed files with 55 additions and 73 deletions

View File

@ -1,5 +1,5 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{set_proxy, Client, ModelInfo};
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
@ -59,27 +59,34 @@ impl Client for LocalAIClient {
}
impl LocalAIClient {
pub fn new(
global_config: SharedConfig,
local_config: LocalAIConfig,
model_info: ModelInfo,
) -> Self {
Self {
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != LocalAIClient::name() {
return None;
}
let local_config = {
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
local_config,
model_info,
}
}))
}
pub fn name() -> &'static str {
"localai"
}
pub fn list_models(local_config: &LocalAIConfig) -> Vec<(String, usize)> {
pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
local_config
.models
.iter()
.map(|v| (v.name.to_string(), v.max_tokens))
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
.collect()
}

View File

@ -6,7 +6,7 @@ use self::{
openai::{OpenAIClient, OpenAIConfig},
};
use anyhow::{bail, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use reqwest::{ClientBuilder, Proxy};
use serde::Deserialize;
@ -41,9 +41,7 @@ pub struct ModelInfo {
impl Default for ModelInfo {
fn default() -> Self {
let client = OpenAIClient::name();
let (name, max_tokens) = &OpenAIClient::list_models(&OpenAIConfig::default())[0];
Self::new(client, name, *max_tokens, 0)
OpenAIClient::list_models(&OpenAIConfig::default(), 0)[0].clone()
}
}
@ -128,43 +126,16 @@ pub trait Client {
}
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
let model_info = config.read().model_info.clone();
let model_info_err = |model_info: &ModelInfo| {
bail!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
)
};
if model_info.client == OpenAIClient::name() {
let local_config = {
if let ClientConfig::OpenAI(c) = &config.read().clients[model_info.index] {
c.clone()
} else {
return model_info_err(&model_info);
}
};
Ok(Box::new(OpenAIClient::new(
config,
local_config,
model_info,
)))
} else if model_info.client == LocalAIClient::name() {
let local_config = {
if let ClientConfig::LocalAI(c) = &config.read().clients[model_info.index] {
c.clone()
} else {
return model_info_err(&model_info);
}
};
Ok(Box::new(LocalAIClient::new(
config,
local_config,
model_info,
)))
} else {
bail!("Unknown client {}", &model_info.client)
}
OpenAIClient::init(config.clone())
.or_else(|| LocalAIClient::init(config.clone()))
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
anyhow!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
)
})
}
pub fn all_clients() -> Vec<&'static str> {
@ -187,14 +158,8 @@ pub fn list_models(config: &Config) -> Vec<ModelInfo> {
.iter()
.enumerate()
.flat_map(|(i, v)| match v {
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c)
.iter()
.map(|(x, y)| ModelInfo::new(OpenAIClient::name(), x, *y, i))
.collect::<Vec<ModelInfo>>(),
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c)
.iter()
.map(|(x, y)| ModelInfo::new(LocalAIClient::name(), x, *y, i))
.collect::<Vec<ModelInfo>>(),
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i),
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i),
})
.collect()
}

View File

@ -1,4 +1,4 @@
use super::{set_proxy, Client, ModelInfo};
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
@ -56,29 +56,39 @@ impl Client for OpenAIClient {
}
impl OpenAIClient {
pub fn new(
global_config: SharedConfig,
local_config: OpenAIConfig,
model_info: ModelInfo,
) -> Self {
Self {
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != OpenAIClient::name() {
return None;
}
let local_config = {
if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
local_config,
model_info,
}
}))
}
pub fn name() -> &'static str {
"openai"
}
pub fn list_models(_local_config: &OpenAIConfig) -> Vec<(String, usize)> {
vec![
("gpt-3.5-turbo".into(), 4096),
("gpt-3.5-turbo-16k".into(), 16384),
("gpt-4".into(), 8192),
("gpt-4-32k".into(), 32768),
pub fn list_models(_local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
[
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16384),
("gpt-4", 8192),
("gpt-4-32k", 32768),
]
.into_iter()
.map(|(name, max_tokens)| ModelInfo::new(Self::name(), name, max_tokens, index))
.collect()
}
pub fn create_config() -> Result<String> {