refactor: clients/* and config.rs (#193)

- add register_clients macro to make it easier to add a new client (see the sketch below)
- drop per-client create_config; each client now declares a const PROMPTS
- move ModelInfo from clients/ to config/
- make each model's max_tokens optional
- improve code quality in config/mod.rs
- add/use macro config_get_fn
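
A minimal sketch of what the new pattern asks of a client (the `example` / `ExampleClient` names are hypothetical, not part of this commit): register it in the `register_clients!` invocation and declare `PROMPTS` instead of hand-writing a `create_config()`:

```rust
// Hypothetical fourth entry; the first three are the real ones from this commit.
register_clients!(
    ("openai", OpenAI, OpenAIConfig, OpenAIClient),
    ("localai", LocalAI, LocalAIConfig, LocalAIClient),
    ("azure-openai", AzureOpenAI, AzureOpenAIConfig, AzureOpenAIClient),
    ("example", Example, ExampleConfig, ExampleClient),
);

impl ExampleClient {
    // Expands to fn get_api_key(&self) -> Result<String>, reading
    // config.api_key and falling back to the EXAMPLE_API_KEY env var.
    config_get_fn!(api_key, get_api_key);

    // (config path, prompt text, required, kind) tuples replace the old
    // hand-written create_config(); a "models[]." prefix targets the
    // first element of the models array.
    pub const PROMPTS: [PromptType<'static>; 2] = [
        ("api_key", "API Key:", true, PromptKind::String),
        ("models[].name", "Model Name:", true, PromptKind::String),
    ];
}
```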
sigoden 8 months ago committed by GitHub
parent 64202758ec
commit 7f2210dbca

Cargo.lock generated

@ -1611,6 +1611,7 @@ version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
dependencies = [
"indexmap 2.0.2",
"itoa",
"ryu",
"serde",

@ -21,7 +21,7 @@ inquire = "0.6.2"
is-terminal = "0.4.9"
reedline = "0.21.0"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
serde_json = { version = "1.0.93", features = ["preserve_order"] }
serde_yaml = "0.9.17"
tokio = { version = "1.26.0", features = ["rt", "time", "macros", "signal"] }
crossbeam = "0.8.2"
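
The new `preserve_order` feature (which pulls in the `indexmap` dependency visible in the Cargo.lock hunk above) backs `serde_json::Map` with an IndexMap, so JSON objects keep their keys in insertion order. Presumably this is why it was enabled here: the client config assembled by `create_config` then serializes to YAML with its fields in prompt order. A quick illustration:

```rust
use serde_json::json;

fn main() {
    let config = json!({ "type": "openai", "api_key": "sk-xxx" });
    // With preserve_order, "type" stays first when re-serialized;
    // without it, keys come back alphabetized ("api_key" first).
    println!("{}", serde_yaml::to_string(&config).unwrap());
}
```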

@ -10,33 +10,30 @@ keybindings: emacs # REPL keybindings. values: emacs, vi
clients:
# All clients have the following configuration:
# ```
# - type: xxxx
# name: nova # Only used to distinguish clients of the same type. Optional
# extra:
# proxy: socks5://127.0.0.1:1080 # Set an https/socks5 proxy server. HTTPS_PROXY/ALL_PROXY also work.
# connect_timeout: 10 # Timeout in seconds for connecting to the server
# ```
# See https://platform.openai.com/docs/quickstart
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
organization_id:
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
models: # Support models
models:
- name: MyGPT4 # Model deployment name
max_tokens: 8192
# See https://github.com/go-skynet/LocalAI
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
models: # Support models
chat_endpoint: /chat/completions
models:
- name: gpt4all-j
max_tokens: 8192

@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
@ -14,17 +11,10 @@ use serde::Deserialize;
use std::env;
#[derive(Debug)]
pub struct AzureOpenAIClient {
global_config: SharedConfig,
config: AzureOpenAIConfig,
model_info: ModelInfo,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: String,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub extra: Option<ExtraConfig>,
@ -33,17 +23,13 @@ pub struct AzureOpenAIConfig {
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIModel {
name: String,
max_tokens: usize,
max_tokens: Option<usize>,
}
#[async_trait]
impl Client for AzureOpenAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
@ -63,27 +49,19 @@ impl Client for AzureOpenAIClient {
}
impl AzureOpenAIClient {
pub const NAME: &str = "azure-openai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}
pub fn name(local_config: &AzureOpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
true,
PromptKind::Integer,
),
];
pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
@ -95,26 +73,6 @@ impl AzureOpenAIClient {
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));
let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
let model_name = prompt_input_model_name()?;
let max_tokens = prompt_input_max_token()?;
client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));
Ok(client_config)
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key
@ -127,11 +85,13 @@ impl AzureOpenAIClient {
})
.ok_or_else(|| anyhow!("Miss api_key"))?;
let api_base = self.get_api_base()?;
let body = openai_build_body(data, self.model_info.name.clone());
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.config.api_base, self.model_info.name
&api_base, self.model_info.name
);
let builder = client.post(url).header("api-key", api_key).json(&body);

@ -0,0 +1,314 @@
use crate::{
config::{Message, SharedConfig},
repl::{ReplyStreamHandler, SharedAbortSignal},
utils::{init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, PromptKind},
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{env, time::Duration};
use tokio::time::sleep;
use super::{openai::OpenAIConfig, ClientConfig};
#[macro_export]
macro_rules! register_clients {
(
$(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
) => {
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
$(
#[serde(rename = $name)]
$config_key($config),
)+
#[serde(other)]
Unknown,
}
$(
#[derive(Debug)]
pub struct $client {
global_config: SharedConfig,
config: $config,
model_info: ModelInfo,
}
impl $client {
pub const NAME: &str = $name;
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}
pub fn name(local_config: &$config) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
}
)+
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
None
$(.or_else(|| $client::init(config.clone())))+
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
anyhow!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
)
})
}
pub fn list_client_types() -> Vec<&'static str> {
vec![$($client::NAME,)+]
}
pub fn create_client_config(client: &str) -> Result<Value> {
$(
if client == $client::NAME {
return create_config(&$client::PROMPTS, $client::NAME)
}
)+
bail!("Unknown client {}", client)
}
pub fn all_models(config: &Config) -> Vec<ModelInfo> {
config
.clients
.iter()
.enumerate()
.flat_map(|(i, v)| match v {
$(ClientConfig::$config_key(c) => $client::list_models(c, i),)+
ClientConfig::Unknown => vec![],
})
.collect()
}
};
}
macro_rules! config_get_fn {
($field_name:ident, $fn_name:ident) => {
fn $fn_name(&self) -> Result<String> {
let value = self.config.$field_name.clone();
value
.or_else(|| {
let env_prefix = Self::name(&self.config);
let env_name =
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
env::var(&env_name).ok()
})
.ok_or_else(|| anyhow::anyhow!("Missing {}", stringify!($field_name)))
}
};
}
#[async_trait]
pub trait Client {
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>);
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let options = self.config().1;
let timeout = options
.as_ref()
.and_then(|v| v.connect_timeout)
.unwrap_or(10);
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
builder = set_proxy(builder, &proxy)?;
let client = builder
.connect_timeout(Duration::from_secs(timeout))
.build()
.with_context(|| "Failed to build client")?;
Ok(client)
}
fn send_message(&self, content: &str) -> Result<String> {
init_tokio_runtime()?.block_on(async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(content);
return Ok(content);
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(content, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get awswer")
})
}
fn send_message_streaming(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
loop {
if abort.aborted() {
break;
}
sleep(Duration::from_millis(100)).await;
}
}
let abort = handler.get_abort();
init_tokio_runtime()?.block_on(async {
tokio::select! {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(content);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(25)).await;
handler.text(&token)?;
}
return Ok(());
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(content, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to get answer")
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
_ = tokio::signal::ctrl_c() => {
abort.set_ctrlc();
Ok(())
}
}
})
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()>;
}
impl Default for ClientConfig {
fn default() -> Self {
Self::OpenAI(OpenAIConfig::default())
}
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
pub proxy: Option<String>,
pub connect_timeout: Option<u64>,
}
#[derive(Debug)]
pub struct SendData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub stream: bool,
}
pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
pub fn create_config(list: &[PromptType], client: &str) -> Result<Value> {
let mut config = json!({
"type": client,
});
for (path, desc, required, kind) in list {
match kind {
PromptKind::String => {
let value = prompt_input_string(desc, *required)?;
set_config_value(&mut config, path, kind, &value);
}
PromptKind::Integer => {
let value = prompt_input_integer(desc, *required)?;
set_config_value(&mut config, path, kind, &value);
}
}
}
let clients = json!(vec![config]);
Ok(clients)
}
fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
let segs: Vec<&str> = path.split('.').collect();
match segs.as_slice() {
[name] => json[name] = to_json(kind, value),
[scope, name] => match scope.split_once('[') {
None => {
if json.get(scope).is_none() {
let mut obj = json!({});
obj[name] = to_json(kind, value);
json[scope] = obj;
} else {
json[scope][name] = to_json(kind, value);
}
}
Some((scope, _)) => {
if json.get(scope).is_none() {
let mut obj = json!({});
obj[name] = to_json(kind, value);
json[scope] = json!([obj]);
} else {
json[scope][0][name] = to_json(kind, value);
}
}
},
_ => {}
}
}
fn to_json(kind: &PromptKind, value: &str) -> Value {
if value.is_empty() {
return Value::Null;
}
match kind {
PromptKind::String => value.into(),
PromptKind::Integer => match value.parse::<i32>() {
Ok(value) => value.into(), // valid integer: store as a JSON number
Err(_) => value.into(), // otherwise fall back to a JSON string
},
}
}
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
let proxy = if let Some(proxy) = proxy {
if proxy.is_empty() || proxy == "false" || proxy == "-" {
return Ok(builder);
}
proxy.clone()
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
proxy
} else {
return Ok(builder);
};
let builder =
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
Ok(builder)
}
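
For reference, a sketch (with invented answers) of the `Value` that `create_config` builds from the AzureOpenAI `PROMPTS`: `set_config_value` expands each `models[].` path into a single-element array, and the whole object is wrapped in a one-element list for the `clients` field:

```rust
use serde_json::json;

// Roughly what create_config(&AzureOpenAIClient::PROMPTS, "azure-openai")
// returns; the sample values here are made up.
fn example() -> serde_json::Value {
    json!([{
        "type": "azure-openai",
        "api_base": "https://RESOURCE.openai.azure.com",
        "api_key": "xxx",
        "models": [{ "name": "MyGPT4", "max_tokens": 8192 }],
    }])
}
```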

@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key_optional, prompt_input_max_token,
prompt_input_model_name, Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
@ -13,13 +10,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use std::env;
#[derive(Debug)]
pub struct LocalAIClient {
global_config: SharedConfig,
config: LocalAIConfig,
model_info: ModelInfo,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIConfig {
pub name: Option<String>,
@ -33,17 +23,13 @@ pub struct LocalAIConfig {
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIModel {
name: String,
max_tokens: usize,
max_tokens: Option<usize>,
}
#[async_trait]
impl Client for LocalAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
@ -63,27 +49,19 @@ impl Client for LocalAIClient {
}
impl LocalAIClient {
pub const NAME: &str = "localai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}
pub fn name(local_config: &LocalAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
false,
PromptKind::Integer,
),
];
pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
@ -95,32 +73,8 @@ impl LocalAIClient {
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));
let api_key = prompt_input_api_key_optional()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
let model_name = prompt_input_model_name()?;
let max_tokens = prompt_input_max_token()?;
client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));
Ok(client_config)
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key.or_else(|| {
let env_prefix = Self::name(&self.config).to_uppercase();
env::var(format!("{env_prefix}_API_KEY")).ok()
});
let api_key = self.get_api_key().ok();
let body = openai_build_body(data, self.model_info.name.clone());

@ -1,272 +1,32 @@
#[macro_use]
mod common;
pub mod azure_openai;
pub mod localai;
pub mod openai;
use self::{
azure_openai::{AzureOpenAIClient, AzureOpenAIConfig},
localai::LocalAIConfig,
openai::{OpenAIClient, OpenAIConfig},
};
pub use common::*;
use self::azure_openai::AzureOpenAIConfig;
use self::localai::LocalAIConfig;
use self::openai::OpenAIConfig;
use crate::{
client::localai::LocalAIClient,
config::{Config, Message, SharedConfig},
repl::{ReplyStreamHandler, SharedAbortSignal},
utils::tokenize,
config::{Config, ModelInfo, SharedConfig},
utils::PromptKind,
};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use inquire::{required, Text};
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
use anyhow::{anyhow, bail, Result};
use serde::Deserialize;
use std::{env, time::Duration};
use tokio::time::sleep;
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
#[serde(rename = "openai")]
OpenAI(OpenAIConfig),
#[serde(rename = "localai")]
LocalAI(LocalAIConfig),
#[serde(rename = "azure-openai")]
AzureOpenAI(AzureOpenAIConfig),
}
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
pub name: String,
pub max_tokens: usize,
pub index: usize,
}
impl Default for ModelInfo {
fn default() -> Self {
OpenAIClient::list_models(&OpenAIConfig::default(), 0)[0].clone()
}
}
impl ModelInfo {
pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self {
Self {
client: client.into(),
name: name.into(),
max_tokens,
index,
}
}
pub fn stringify(&self) -> String {
format!("{}:{}", self.client, self.name)
}
}
#[derive(Debug)]
pub struct SendData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub stream: bool,
}
#[async_trait]
pub trait Client {
fn config(&self) -> &SharedConfig;
fn extra_config(&self) -> &Option<ExtraConfig>;
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let options = self.extra_config();
let timeout = options
.as_ref()
.and_then(|v| v.connect_timeout)
.unwrap_or(10);
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
builder = set_proxy(builder, &proxy)?;
let client = builder
.connect_timeout(Duration::from_secs(timeout))
.build()
.with_context(|| "Failed to build client")?;
Ok(client)
}
fn send_message(&self, content: &str) -> Result<String> {
init_tokio_runtime()?.block_on(async {
if self.config().read().dry_run {
let content = self.config().read().echo_messages(content);
return Ok(content);
}
let client = self.build_client()?;
let data = self.config().read().prepare_send_data(content, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to fetch")
})
}
fn send_message_streaming(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
loop {
if abort.aborted() {
break;
}
sleep(Duration::from_millis(100)).await;
}
}
let abort = handler.get_abort();
init_tokio_runtime()?.block_on(async {
tokio::select! {
ret = async {
if self.config().read().dry_run {
let content = self.config().read().echo_messages(content);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(25)).await;
handler.text(&token)?;
}
return Ok(());
}
let client = self.build_client()?;
let data = self.config().read().prepare_send_data(content, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to fetch stream")
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
_ = tokio::signal::ctrl_c() => {
abort.set_ctrlc();
Ok(())
}
}
})
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()>;
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
pub proxy: Option<String>,
pub connect_timeout: Option<u64>,
}
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
OpenAIClient::init(config.clone())
.or_else(|| LocalAIClient::init(config.clone()))
.or_else(|| AzureOpenAIClient::init(config.clone()))
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
anyhow!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
)
})
}
pub fn list_client_types() -> Vec<&'static str> {
vec![
OpenAIClient::NAME,
LocalAIClient::NAME,
AzureOpenAIClient::NAME,
]
}
pub fn create_client_config(client: &str) -> Result<String> {
if client == OpenAIClient::NAME {
OpenAIClient::create_config()
} else if client == LocalAIClient::NAME {
LocalAIClient::create_config()
} else if client == AzureOpenAIClient::NAME {
AzureOpenAIClient::create_config()
} else {
bail!("Unknown client {}", &client)
}
}
pub fn list_models(config: &Config) -> Vec<ModelInfo> {
config
.clients
.iter()
.enumerate()
.flat_map(|(i, v)| match v {
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i),
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i),
ClientConfig::AzureOpenAI(c) => AzureOpenAIClient::list_models(c, i),
})
.collect()
}
pub(crate) fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.with_context(|| "Failed to init tokio")
}
pub(crate) fn prompt_input_api_base() -> Result<String> {
Text::new("API Base:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_api_key() -> Result<String> {
Text::new("API Key:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_api_key_optional() -> Result<String> {
Text::new("API Key:").prompt().map_err(prompt_op_err)
}
pub(crate) fn prompt_input_model_name() -> Result<String> {
Text::new("Model Name:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_max_token() -> Result<String> {
Text::new("Max tokens:")
.with_default("4096")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_op_err<T>(_: T) -> anyhow::Error {
anyhow!("An error happened, try again later.")
}
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
let proxy = if let Some(proxy) = proxy {
if proxy.is_empty() || proxy == "false" || proxy == "-" {
return Ok(builder);
}
proxy.clone()
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
proxy
} else {
return Ok(builder);
};
let builder =
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
Ok(builder)
}
use serde_json::Value;
register_clients!(
("openai", OpenAI, OpenAIConfig, OpenAIClient),
("localai", LocalAI, LocalAIConfig, LocalAIClient),
(
"azure-openai",
AzureOpenAI,
AzureOpenAIConfig,
AzureOpenAIClient
),
);

@ -1,4 +1,4 @@
use super::{prompt_input_api_key, Client, ClientConfig, ExtraConfig, ModelInfo, SendData};
use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
@ -14,12 +14,12 @@ use std::env;
const API_BASE: &str = "https://api.openai.com/v1";
#[derive(Debug)]
pub struct OpenAIClient {
global_config: SharedConfig,
config: OpenAIConfig,
model_info: ModelInfo,
}
const MODELS: [(&str, usize); 4] = [
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16384),
("gpt-4", 8192),
("gpt-4-32k", 32768),
];
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
@ -31,12 +31,8 @@ pub struct OpenAIConfig {
#[async_trait]
impl Client for OpenAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
@ -56,61 +52,25 @@ impl Client for OpenAIClient {
}
impl OpenAIClient {
pub const NAME: &str = "openai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}
config_get_fn!(api_key, get_api_key);
pub fn name(local_config: &OpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
[
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16384),
("gpt-4", 8192),
("gpt-4-32k", 32768),
]
.into_iter()
.map(|(name, max_tokens)| ModelInfo::new(client, name, max_tokens, index))
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
Ok(client_config)
MODELS
.into_iter()
.map(|(name, max_tokens)| ModelInfo::new(client, name, Some(max_tokens), index))
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let env_prefix = Self::name(&self.config).to_uppercase();
let api_key = self.config.api_key.clone();
let api_key = api_key
.or_else(|| env::var(format!("{env_prefix}_API_KEY")).ok())
.ok_or_else(|| anyhow!("Miss api_key"))?;
let api_key = self.get_api_key()?;
let body = openai_build_body(data, self.model_info.name.clone());
let env_prefix = Self::name(&self.config).to_uppercase();
let api_base = env::var(format!("{env_prefix}_API_BASE"))
.ok()
.unwrap_or_else(|| API_BASE.to_string());
@ -127,20 +87,20 @@ impl OpenAIClient {
}
}
pub(crate) async fn openai_send_message(builder: RequestBuilder) -> Result<String> {
pub async fn openai_send_message(builder: RequestBuilder) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
bail!("{err_msg}");
}
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok(output.to_string())
}
pub(crate) async fn openai_send_message_streaming(
pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
@ -148,7 +108,7 @@ pub(crate) async fn openai_send_message_streaming(
if !res.status().is_success() {
let data: Value = res.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
bail!("{err_msg}");
}
bail!("Request failed");
}
@ -159,37 +119,30 @@ pub(crate) async fn openai_send_message_streaming(
break;
}
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
handler.text(text)?;
}
handler.text(text)?;
}
Ok(())
}
pub(crate) fn openai_build_body(data: SendData, model: String) -> Value {
pub fn openai_build_body(data: SendData, model: String) -> Value {
let SendData {
messages,
temperature,
stream,
} = data;
let mut body = json!({
"model": model,
"messages": messages,
});
if let Some(v) = temperature {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
body["temperature"] = v.into();
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
body["stream"] = true.into();
}
body
}
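
To make the shape concrete, a sketch of the body this assembles for a streaming request with a temperature set (the message content is invented):

```rust
use serde_json::json;

fn example_body() -> serde_json::Value {
    // "temperature" and "stream" are only inserted when requested.
    json!({
        "model": "gpt-3.5-turbo",
        "messages": [{ "role": "user", "content": "Hello" }],
        "temperature": 0.7,
        "stream": true,
    })
}
```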

@ -17,7 +17,7 @@ impl Message {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,
@ -25,6 +25,13 @@ pub enum MessageRole {
User,
}
impl MessageRole {
#[allow(dead_code)]
pub fn is_system(&self) -> bool {
matches!(self, MessageRole::System)
}
}
pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
let mut num_tokens = 0;
for message in messages.iter() {

@ -1,19 +1,20 @@
mod message;
mod model_info;
mod role;
mod session;
pub use self::message::Message;
pub use self::model_info::ModelInfo;
use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};
use crate::client::openai::{OpenAIClient, OpenAIConfig};
use crate::client::{
create_client_config, list_client_types, list_models, prompt_op_err, ClientConfig, ExtraConfig,
ModelInfo, SendData,
all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient,
SendData,
};
use crate::config::message::num_tokens_from_messages;
use crate::render::RenderOptions;
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Select, Text};
@ -49,6 +50,8 @@ const SET_COMPLETIONS: [&str; 7] = [
".set dry_run false",
];
const CLIENTS_FIELD: &str = "clients";
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
@ -61,7 +64,7 @@ pub struct Config {
pub save: bool,
/// Whether to disable highlight
pub highlight: bool,
/// Used only for debugging
/// Dry-run flag
pub dry_run: bool,
/// Whether to use a light theme
pub light_theme: bool,
@ -105,7 +108,7 @@ impl Default for Config {
wrap_code: false,
auto_copy: false,
keybindings: Default::default(),
clients: vec![ClientConfig::OpenAI(OpenAIConfig::default())],
clients: vec![ClientConfig::default()],
roles: vec![],
role: None,
session: None,
@ -145,11 +148,11 @@ impl Config {
config.temperature = config.default_temperature;
config.set_model_info()?;
config.merge_env_vars();
config.load_roles()?;
config.ensure_sessions_dir()?;
config.detect_theme()?;
config.setup_model_info()?;
config.setup_highlight();
config.setup_light_theme()?;
Ok(config)
}
@ -296,8 +299,10 @@ impl Config {
vec![message]
};
let tokens = num_tokens_from_messages(&messages);
if tokens >= self.model_info.max_tokens {
bail!("Exceed max tokens limit")
if let Some(max_tokens) = self.model_info.max_tokens {
if tokens >= max_tokens {
bail!("Exceed max tokens limit")
}
}
Ok(messages)
@ -318,7 +323,7 @@ impl Config {
}
pub fn set_model(&mut self, value: &str) -> Result<()> {
let models = list_models(self);
let models = all_models(self);
let mut model_info = None;
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.stringify() == value) {
@ -339,14 +344,6 @@ impl Config {
}
}
pub const fn get_reamind_tokens(&self) -> usize {
let mut tokens = self.model_info.max_tokens;
if let Some(session) = self.session.as_ref() {
tokens = tokens.saturating_sub(session.tokens);
}
tokens
}
pub fn info(&self) -> Result<String> {
let path_info = |path: &Path| {
let state = if path.exists() { "" } else { " ⚠️" };
@ -390,12 +387,7 @@ impl Config {
completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(
list_models(self)
.iter()
.map(|v| format!(".model {}", v.stringify())),
);
completion.extend(
list_models(self)
all_models(self)
.iter()
.map(|v| format!(".model {}", v.stringify())),
);
@ -504,6 +496,14 @@ impl Config {
name = Text::new("Session name:").with_default(&name).prompt()?;
}
let session_path = Self::session_file(&name)?;
let sessions_dir = session_path.parent().ok_or_else(|| {
anyhow!("Unable to save session file to {}", session_path.display())
})?;
if !sessions_dir.exists() {
create_dir_all(sessions_dir).with_context(|| {
format!("Failed to create session_dir '{}'", sessions_dir.display())
})?;
}
session.save(&session_path)?;
}
}
@ -556,6 +556,24 @@ impl Config {
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
}
pub fn render_prompt_right(&self) -> String {
if let Some(session) = &self.session {
let tokens = session.tokens;
// e.g. 10000(32%)
match self.model_info.max_tokens {
Some(max_tokens) => {
let ratio = tokens as f32 / max_tokens as f32;
let percent = ratio * 100.0;
let percent = (percent * 100.0).round() / 100.0;
format!("{tokens}({percent}%)")
}
None => format!("{tokens}"),
}
} else {
String::new()
}
}
pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
let messages = self.build_messages(content)?;
Ok(SendData {
@ -585,11 +603,20 @@ impl Config {
}
fn load_config(config_path: &Path) -> Result<Self> {
let content = read_to_string(config_path)
.with_context(|| format!("Failed to load config at {}", config_path.display()))?;
let ctx = || format!("Failed to load config at {}", config_path.display());
let content = read_to_string(config_path).with_context(ctx)?;
let config: Self = serde_yaml::from_str(&content)
.with_context(|| format!("Invalid config at {}", config_path.display()))?;
.map_err(|err| {
let err_msg = err.to_string();
if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
anyhow!("clients: invalid value")
} else {
anyhow!("{err_msg}")
}
})
.with_context(ctx)?;
Ok(config)
}
@ -606,11 +633,11 @@ impl Config {
Ok(())
}
fn set_model_info(&mut self) -> Result<()> {
fn setup_model_info(&mut self) -> Result<()> {
let model = match &self.model {
Some(v) => v.clone(),
None => {
let models = self::list_models(self);
let models = all_models(self);
if models.is_empty() {
bail!("No available model");
}
@ -622,7 +649,7 @@ impl Config {
Ok(())
}
fn merge_env_vars(&mut self) {
fn setup_highlight(&mut self) {
if let Ok(value) = env::var("NO_COLOR") {
let mut no_color = false;
set_bool(&mut no_color, &value);
@ -632,17 +659,7 @@ impl Config {
}
}
fn ensure_sessions_dir(&self) -> Result<()> {
let sessions_dir = Self::sessions_dir()?;
if !sessions_dir.exists() {
create_dir_all(&sessions_dir).with_context(|| {
format!("Failed to create session_dir '{}'", sessions_dir.display())
})?;
}
Ok(())
}
fn detect_theme(&mut self) -> Result<()> {
fn setup_light_theme(&mut self) -> Result<()> {
if self.light_theme {
return Ok(());
}
@ -660,7 +677,7 @@ impl Config {
fn compat_old_config(&mut self, config_path: &PathBuf) -> Result<()> {
let content = read_to_string(config_path)?;
let value: serde_json::Value = serde_yaml::from_str(&content)?;
if value.get("clients").is_some() {
if value.get(CLIENTS_FIELD).is_some() {
return Ok(());
}
@ -725,16 +742,18 @@ fn create_config_file(config_path: &Path) -> Result<()> {
exit(0);
}
let client = Select::new("AI Platform:", list_client_types())
let client = Select::new("Platform:", list_client_types())
.prompt()
.map_err(prompt_op_err)?;
let mut raw_config = create_client_config(client)?;
let mut config = serde_json::json!({});
config["model"] = client.into();
config[CLIENTS_FIELD] = create_client_config(client)?;
raw_config.push_str(&format!("model: {client}\n"));
let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?;
ensure_parent_exists(config_path)?;
std::fs::write(config_path, raw_config).with_context(|| "Failed to write to config file")?;
std::fs::write(config_path, config_data).with_context(|| "Failed to write to config file")?;
#[cfg(unix)]
{
use std::os::unix::prelude::PermissionsExt;

@ -0,0 +1,27 @@
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
pub name: String,
pub max_tokens: Option<usize>,
pub index: usize,
}
impl Default for ModelInfo {
fn default() -> Self {
ModelInfo::new("", "", None, 0)
}
}
impl ModelInfo {
pub fn new(client: &str, name: &str, max_tokens: Option<usize>, index: usize) -> Self {
Self {
client: client.into(),
name: name.into(),
max_tokens,
index,
}
}
pub fn stringify(&self) -> String {
format!("{}:{}", self.client, self.name)
}
}
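
A small usage sketch: `stringify()` yields the `client:model` form that `--list-models` prints and `.model` completions match against.

```rust
fn main() {
    // Assumes ModelInfo from the new config/model_info module is in scope.
    let info = ModelInfo::new("openai", "gpt-4", Some(8192), 0);
    assert_eq!(info.stringify(), "openai:gpt-4");
}
```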

@ -12,7 +12,7 @@ use crate::config::{Config, SharedConfig};
use anyhow::Result;
use clap::Parser;
use client::{init_client, list_models};
use client::{all_models, init_client};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
@ -36,7 +36,7 @@ fn main() -> Result<()> {
exit(0);
}
if cli.list_models {
for model in list_models(&config.read()) {
for model in all_models(&config.read()) {
println!("{}", model.stringify());
}
exit(0);

@ -32,11 +32,7 @@ impl Prompt for ReplPrompt {
}
fn render_prompt_right(&self) -> Cow<str> {
if self.config.read().session.is_none() {
Cow::Borrowed("")
} else {
self.config.read().get_reamind_tokens().to_string().into()
}
Cow::Owned(self.config.read().render_prompt_right())
}
fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<str> {

@ -1,6 +1,8 @@
mod prompt_input;
mod split_line;
mod tiktoken;
pub use self::prompt_input::*;
pub use self::split_line::*;
pub use self::tiktoken::cl100k_base_singleton;
@ -63,3 +65,11 @@ pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option<bool> {
let light = v > 128.0;
Some(light)
}
pub fn init_tokio_runtime() -> anyhow::Result<tokio::runtime::Runtime> {
use anyhow::Context;
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.with_context(|| "Failed to init tokio")
}

@ -0,0 +1,58 @@
use inquire::{required, validator::Validation, Text};
const MSG_REQUIRED: &str = "This field is required";
const MSG_OPTIONAL: &str = "Optional field - Press ↵ to skip";
pub fn prompt_input_string(desc: &str, required: bool) -> anyhow::Result<String> {
let mut text = Text::new(desc);
if required {
text = text.with_validator(required!(MSG_REQUIRED))
} else {
text = text.with_help_message(MSG_OPTIONAL)
}
text.prompt().map_err(prompt_op_err)
}
pub fn prompt_input_integer(desc: &str, required: bool) -> anyhow::Result<String> {
let mut text = Text::new(desc);
if required {
text = text.with_validator(|text: &str| {
let out = if text.is_empty() {
Validation::Invalid(MSG_REQUIRED.into())
} else {
validate_integer(text)
};
Ok(out)
})
} else {
text = text
.with_validator(|text: &str| {
let out = if text.is_empty() {
Validation::Valid
} else {
validate_integer(text)
};
Ok(out)
})
.with_help_message(MSG_OPTIONAL)
}
text.prompt().map_err(prompt_op_err)
}
pub fn prompt_op_err<T>(_: T) -> anyhow::Error {
anyhow::anyhow!("Not finish questionnaire, try again later!")
}
#[derive(Debug, Clone, Copy)]
pub enum PromptKind {
String,
Integer,
}
fn validate_integer(text: &str) -> Validation {
if text.parse::<i32>().is_err() {
Validation::Invalid("Must be a integer".into())
} else {
Validation::Valid
}
}