mirror of
https://github.com/sigoden/aichat
synced 2024-11-16 06:15:26 +00:00
chore: improve client-related code quality
This commit is contained in:
parent
557bed1459
commit
985f8c0946
@ -1,5 +1,5 @@
|
||||
use super::openai::{openai_send_message, openai_send_message_streaming};
|
||||
use super::{set_proxy, Client, ModelInfo};
|
||||
use super::{set_proxy, Client, ClientConfig, ModelInfo};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
@ -59,27 +59,34 @@ impl Client for LocalAIClient {
|
||||
}
|
||||
|
||||
impl LocalAIClient {
|
||||
pub fn new(
|
||||
global_config: SharedConfig,
|
||||
local_config: LocalAIConfig,
|
||||
model_info: ModelInfo,
|
||||
) -> Self {
|
||||
Self {
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
if model_info.client != LocalAIClient::name() {
|
||||
return None;
|
||||
}
|
||||
let local_config = {
|
||||
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
local_config,
|
||||
model_info,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"localai"
|
||||
}
|
||||
|
||||
pub fn list_models(local_config: &LocalAIConfig) -> Vec<(String, usize)> {
|
||||
pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
local_config
|
||||
.models
|
||||
.iter()
|
||||
.map(|v| (v.name.to_string(), v.max_tokens))
|
||||
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,7 @@ use self::{
|
||||
openai::{OpenAIClient, OpenAIConfig},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{ClientBuilder, Proxy};
|
||||
use serde::Deserialize;
|
||||
@ -41,9 +41,7 @@ pub struct ModelInfo {
|
||||
|
||||
impl Default for ModelInfo {
|
||||
fn default() -> Self {
|
||||
let client = OpenAIClient::name();
|
||||
let (name, max_tokens) = &OpenAIClient::list_models(&OpenAIConfig::default())[0];
|
||||
Self::new(client, name, *max_tokens, 0)
|
||||
OpenAIClient::list_models(&OpenAIConfig::default(), 0)[0].clone()
|
||||
}
|
||||
}
|
||||
|
||||
@ -128,43 +126,16 @@ pub trait Client {
|
||||
}
|
||||
|
||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
let model_info = config.read().model_info.clone();
|
||||
let model_info_err = |model_info: &ModelInfo| {
|
||||
bail!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
)
|
||||
};
|
||||
if model_info.client == OpenAIClient::name() {
|
||||
let local_config = {
|
||||
if let ClientConfig::OpenAI(c) = &config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return model_info_err(&model_info);
|
||||
}
|
||||
};
|
||||
Ok(Box::new(OpenAIClient::new(
|
||||
config,
|
||||
local_config,
|
||||
model_info,
|
||||
)))
|
||||
} else if model_info.client == LocalAIClient::name() {
|
||||
let local_config = {
|
||||
if let ClientConfig::LocalAI(c) = &config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return model_info_err(&model_info);
|
||||
}
|
||||
};
|
||||
Ok(Box::new(LocalAIClient::new(
|
||||
config,
|
||||
local_config,
|
||||
model_info,
|
||||
)))
|
||||
} else {
|
||||
bail!("Unknown client {}", &model_info.client)
|
||||
}
|
||||
OpenAIClient::init(config.clone())
|
||||
.or_else(|| LocalAIClient::init(config.clone()))
|
||||
.ok_or_else(|| {
|
||||
let model_info = config.read().model_info.clone();
|
||||
anyhow!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn all_clients() -> Vec<&'static str> {
|
||||
@ -187,14 +158,8 @@ pub fn list_models(config: &Config) -> Vec<ModelInfo> {
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, v)| match v {
|
||||
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c)
|
||||
.iter()
|
||||
.map(|(x, y)| ModelInfo::new(OpenAIClient::name(), x, *y, i))
|
||||
.collect::<Vec<ModelInfo>>(),
|
||||
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c)
|
||||
.iter()
|
||||
.map(|(x, y)| ModelInfo::new(LocalAIClient::name(), x, *y, i))
|
||||
.collect::<Vec<ModelInfo>>(),
|
||||
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i),
|
||||
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use super::{set_proxy, Client, ModelInfo};
|
||||
use super::{set_proxy, Client, ClientConfig, ModelInfo};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
@ -56,29 +56,39 @@ impl Client for OpenAIClient {
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
pub fn new(
|
||||
global_config: SharedConfig,
|
||||
local_config: OpenAIConfig,
|
||||
model_info: ModelInfo,
|
||||
) -> Self {
|
||||
Self {
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
if model_info.client != OpenAIClient::name() {
|
||||
return None;
|
||||
}
|
||||
let local_config = {
|
||||
if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
local_config,
|
||||
model_info,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
pub fn list_models(_local_config: &OpenAIConfig) -> Vec<(String, usize)> {
|
||||
vec![
|
||||
("gpt-3.5-turbo".into(), 4096),
|
||||
("gpt-3.5-turbo-16k".into(), 16384),
|
||||
("gpt-4".into(), 8192),
|
||||
("gpt-4-32k".into(), 32768),
|
||||
pub fn list_models(_local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
[
|
||||
("gpt-3.5-turbo", 4096),
|
||||
("gpt-3.5-turbo-16k", 16384),
|
||||
("gpt-4", 8192),
|
||||
("gpt-4-32k", 32768),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|(name, max_tokens)| ModelInfo::new(Self::name(), name, max_tokens, index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
|
Loading…
Reference in New Issue
Block a user