feat: improve client (#189)

This commit is contained in:
sigoden 2023-10-31 16:50:08 +08:00 committed by GitHub
parent 9b614600c6
commit a137483b03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 349 additions and 347 deletions

View File

@ -44,10 +44,8 @@ On first launch, aichat will guide you through the configuration.
```
> No config file, create a new one? Yes
> Select AI? openai
> API key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
> Has Organization? No
> Use proxy? No
> AI Platform: openai
> API Key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
> Save chat messages Yes
```
@ -64,24 +62,35 @@ wrap_code: false # Whether to wrap code blocks
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi
clients: # Setup AIs
clients:
# All clients have the following configuration:
# ```
# - type: xxxx
# name: nova # Only use it to distinguish clients with the same client type. Optional
# extra:
# proxy: socks5://127.0.0.1:1080 # Specify an https/socks5 proxy server. Note that HTTPS_PROXY/ALL_PROXY also work.
# connect_timeout: 10 # Set a timeout in seconds for connecting to the server
# ```
# See https://platform.openai.com/docs/quickstart
- type: openai # OpenAI configuration
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure openai configuration
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
models: # Supported models
- name: MyGPT4 # Model deployment name
max_tokens: 8192
# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
models: # Supported models
- name: gpt4all-j
max_tokens: 8192
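Every client block now shares the optional `name` and `extra` settings described in the comment above. As a rough illustration of how that block maps onto the `ExtraConfig` struct added in `src/client/mod.rs` later in this diff, here is a minimal deserialization sketch (assuming the `serde_yaml` crate; `DemoClient` is a hypothetical stand-in for the real tagged `ClientConfig` enum):

```rust
use serde::Deserialize;

// Mirrors the ExtraConfig struct introduced in src/client/mod.rs below.
#[derive(Debug, Deserialize, Default)]
struct ExtraConfig {
    proxy: Option<String>,
    connect_timeout: Option<u64>,
}

// Hypothetical stand-in for one entry of the `clients:` list.
#[derive(Debug, Deserialize)]
struct DemoClient {
    #[serde(rename = "type")]
    kind: String,
    name: Option<String>,
    extra: Option<ExtraConfig>,
}

fn main() {
    let yaml = "
type: openai
name: nova
extra:
  proxy: socks5://127.0.0.1:1080
  connect_timeout: 10
";
    let client: DemoClient = serde_yaml::from_str(yaml).unwrap();
    let extra = client.extra.unwrap();
    println!("{} ({:?}) proxy={:?}", client.kind, client.name, extra.proxy);
    assert_eq!(extra.connect_timeout, Some(10));
}
```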

View File

@ -8,30 +8,35 @@ wrap_code: false # Whether to wrap code blocks
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi
clients: # Setup AIs
clients:
# All clients have the following configuration:
# ```
# - type: xxxx
# name: nova # Only use it to distinguish clients with the same client type. Optional
# extra:
# proxy: socks5://127.0.0.1:1080 # Specify an https/socks5 proxy server. Note that HTTPS_PROXY/ALL_PROXY also work.
# connect_timeout: 10 # Set a timeout in seconds for connecting to the server
# ```
# See https://platform.openai.com/docs/quickstart
- type: openai # OpenAI configuration
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
proxy: socks5://127.0.0.1:1080
connect_timeout: 10
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure openai configuration
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
models: # Supported models
- name: MyGPT4 # Model deployment name
max_tokens: 8192
proxy: socks5://127.0.0.1:1080 # Set proxy server. Optional
connect_timeout: 10 # Set a timeout in seconds for connect to gpt. Optional
# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
models: # Supported models
- name: gpt4all-j
max_tokens: 8192
proxy: socks5://127.0.0.1:1080
connect_timeout: 10
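One practical consequence of the new `name` field, visible in the client diffs below: environment-variable fallbacks are now derived from the client name (or its type when unnamed) instead of being hard-coded. The openai and localai clients read `{PREFIX}_API_KEY`, openai also honors `{PREFIX}_API_BASE`, and azure-openai reads `{PREFIX}_OPENAI_KEY`. A small sketch of that resolution order, with a hypothetical helper name:

```rust
use std::env;

// Hypothetical helper mirroring the lookup order in the client diffs
// below: an api_key from the config wins; otherwise the client name
// (or its type, when unnamed) is upper-cased into an env-var prefix.
fn resolve_api_key(configured: Option<String>, client_name: &str) -> Option<String> {
    configured.or_else(|| {
        let env_prefix = client_name.to_uppercase();
        // e.g. an openai client named `nova` reads NOVA_API_KEY
        env::var(format!("{env_prefix}_API_KEY")).ok()
    })
}

fn main() {
    env::set_var("NOVA_API_KEY", "sk-from-env");
    assert_eq!(
        resolve_api_key(None, "nova").as_deref(),
        Some("sk-from-env")
    );
    // An explicit key in the config file takes precedence.
    assert_eq!(
        resolve_api_key(Some("sk-from-config".into()), "nova").as_deref(),
        Some("sk-from-config")
    );
}
```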

View File

@ -1,34 +1,33 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::json;
use std::env;
use std::time::Duration;
#[allow(clippy::module_name_repetitions)]
use std::env;
#[derive(Debug)]
pub struct AzureOpenAIClient {
global_config: SharedConfig,
local_config: AzureOpenAIConfig,
config: AzureOpenAIConfig,
model_info: ModelInfo,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: String,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connect to server
pub connect_timeout: Option<u64>,
pub extra: Option<ExtraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
@ -39,32 +38,36 @@ pub struct AzureOpenAIModel {
#[async_trait]
impl Client for AzureOpenAIClient {
fn get_config(&self) -> &SharedConfig {
fn config(&self) -> &SharedConfig {
&self.global_config
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
content: &str,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
impl AzureOpenAIClient {
pub const NAME: &str = "azure-openai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != AzureOpenAIClient::name() {
return None;
}
let local_config = {
let config = {
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
@ -73,59 +76,37 @@ impl AzureOpenAIClient {
};
Some(Box::new(Self {
global_config,
local_config,
config,
model_info,
}))
}
pub fn name() -> &'static str {
"azure-openai"
pub fn name(local_config: &AzureOpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
local_config
.models
.iter()
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
let api_base = Text::new("api_base:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for api base, try again later."))?;
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));
if env::var("AZURE_OPENAI_KEY").is_err() {
let api_key = Text::new("API key:").prompt().map_err(|_| {
anyhow!("An error happened when asking for api key, try again later.")
})?;
let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
}
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
anyhow!("An error happened when asking for model name, try again later.")
})?;
let model_name = prompt_input_model_name()?;
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
anyhow!("An error happened when asking for max tokens, try again later.")
})?;
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
let max_tokens = prompt_input_max_token()?;
client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
@ -134,50 +115,26 @@ impl AzureOpenAIClient {
Ok(client_config)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let messages = self.global_config.read().build_messages(content)?;
let mut body = json!({
"messages": messages,
});
if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let client = {
let mut builder = ReqwestClient::builder();
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key
.or_else(|| {
let env_prefix = match &self.config.name {
None => "AZURE".into(),
Some(v) => v.to_uppercase(),
};
let mut api_base = self.local_config.api_base.clone();
if !api_base.ends_with('/') {
api_base = format!("{api_base}/");
}
env::var(format!("{env_prefix}_OPENAI_KEY")).ok()
})
.ok_or_else(|| anyhow!("Missing api_key"))?;
let body = openai_build_body(data, self.model_info.name.clone());
let url = format!(
"{api_base}openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.model_info.name
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.config.api_base, self.model_info.name
);
let mut builder = client.post(url);
if let Some(api_key) = &self.local_config.api_key {
builder = builder.header("api-key", api_key)
} else if let Ok(api_key) = env::var("AZURE_OPENAI_KEY") {
builder = builder.header("api-key", api_key)
}
builder = builder.json(&body);
let builder = client.post(url).header("api-key", api_key).json(&body);
Ok(builder)
}
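Worth noting in this rewrite: the trailing-slash normalization of `api_base` is gone, the URL is assembled in a single `format!`, and the API key is resolved up front instead of being conditionally attached. A quick sketch of the resulting endpoint shape (hypothetical helper; API version string taken from the diff):

```rust
// Sketch of the Azure endpoint construction after this change; api_base
// is used as-is, so it is expected without a trailing slash.
fn azure_chat_url(api_base: &str, deployment: &str) -> String {
    format!(
        "{api_base}/openai/deployments/{deployment}/chat/completions?api-version=2023-05-15"
    )
}

fn main() {
    assert_eq!(
        azure_chat_url("https://RESOURCE.openai.azure.com", "MyGPT4"),
        "https://RESOURCE.openai.azure.com/openai/deployments/MyGPT4/chat/completions?api-version=2023-05-15"
    );
}
```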

View File

@ -1,33 +1,33 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key_optional, prompt_input_max_token,
prompt_input_model_name, Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use anyhow::{anyhow, Context, Result};
use anyhow::Result;
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::json;
use std::env;
use std::time::Duration;
#[derive(Debug)]
pub struct LocalAIClient {
global_config: SharedConfig,
local_config: LocalAIConfig,
config: LocalAIConfig,
model_info: ModelInfo,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIConfig {
pub url: String,
pub name: Option<String>,
pub api_base: String,
pub api_key: Option<String>,
pub chat_endpoint: Option<String>,
pub models: Vec<LocalAIModel>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connect to server
pub connect_timeout: Option<u64>,
pub extra: Option<ExtraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
@ -38,32 +38,36 @@ pub struct LocalAIModel {
#[async_trait]
impl Client for LocalAIClient {
fn get_config(&self) -> &SharedConfig {
fn config(&self) -> &SharedConfig {
&self.global_config
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
content: &str,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
impl LocalAIClient {
pub const NAME: &str = "localai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != LocalAIClient::name() {
return None;
}
let local_config = {
let config = {
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
@ -72,64 +76,37 @@ impl LocalAIClient {
};
Some(Box::new(Self {
global_config,
local_config,
config,
model_info,
}))
}
pub fn name() -> &'static str {
"localai"
pub fn name(local_config: &LocalAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
local_config
.models
.iter()
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let url = Text::new("URL:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for url, try again later."))?;
client_config.push_str(&format!(" url: {url}\n"));
let ans = Confirm::new("Use auth?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let api_key = Text::new("API key:").prompt().map_err(|_| {
anyhow!("An error happened when asking for api key, try again later.")
})?;
let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));
let api_key = prompt_input_api_key_optional()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
}
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
anyhow!("An error happened when asking for model name, try again later.")
})?;
let model_name = prompt_input_model_name()?;
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
anyhow!("An error happened when asking for max tokens, try again later.")
})?;
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
let max_tokens = prompt_input_max_token()?;
client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
@ -138,41 +115,27 @@ impl LocalAIClient {
Ok(client_config)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let messages = self.global_config.read().build_messages(content)?;
let mut body = json!({
"model": self.model_info.name,
"messages": messages,
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key.or_else(|| {
let env_prefix = Self::name(&self.config).to_uppercase();
env::var(format!("{env_prefix}_API_KEY")).ok()
});
if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
let body = openai_build_body(data, self.model_info.name.clone());
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let chat_endpoint = self
.config
.chat_endpoint
.as_deref()
.unwrap_or("/chat/completions");
let client = {
let mut builder = ReqwestClient::builder();
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
};
let url = format!("{}{chat_endpoint}", self.config.api_base);
let mut builder = client.post(&self.local_config.url);
if let Some(api_key) = &self.local_config.api_key {
builder = builder.bearer_auth(api_key);
} else if let Ok(api_key) = env::var("LOCALAI_API_KEY") {
let mut builder = client.post(url).json(&body);
if let Some(api_key) = api_key {
builder = builder.bearer_auth(api_key);
}
builder = builder.json(&body);
Ok(builder)
}
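The localai client switches from a single `url` field to `api_base` plus an optional `chat_endpoint`, and bearer auth is only attached when a key was actually resolved. A condensed sketch of that request construction (hypothetical helper; the JSON body and other headers are omitted):

```rust
use reqwest::{Client, RequestBuilder};

// Hypothetical helper condensing request_builder above: api_base plus an
// optional chat_endpoint (default /chat/completions), bearer auth only
// when an API key was resolved.
fn localai_request(
    client: &Client,
    api_base: &str,
    chat_endpoint: Option<&str>,
    api_key: Option<&str>,
) -> RequestBuilder {
    let url = format!(
        "{}{}",
        api_base,
        chat_endpoint.unwrap_or("/chat/completions")
    );
    let mut builder = client.post(url);
    if let Some(key) = api_key {
        builder = builder.bearer_auth(key);
    }
    builder
}

fn main() {
    let client = Client::new();
    let builder = localai_request(&client, "http://localhost:8080/v1", None, Some("xxx"));
    println!("{builder:?}");
}
```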

View File

@ -8,20 +8,21 @@ use self::{
openai::{OpenAIClient, OpenAIConfig},
};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use reqwest::{ClientBuilder, Proxy};
use serde::Deserialize;
use std::{env, time::Duration};
use tokio::time::sleep;
use crate::{
client::localai::LocalAIClient,
config::{Config, SharedConfig},
config::{Config, Message, SharedConfig},
repl::{ReplyStreamHandler, SharedAbortSignal},
utils::tokenize,
};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use inquire::{required, Text};
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
use serde::Deserialize;
use std::{env, time::Duration};
use tokio::time::sleep;
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
@ -32,7 +33,6 @@ pub enum ClientConfig {
#[serde(rename = "azure-openai")]
AzureOpenAI(AzureOpenAIConfig),
}
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
@ -61,17 +61,43 @@ impl ModelInfo {
}
}
#[derive(Debug)]
pub struct SendData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub stream: bool,
}
#[async_trait]
pub trait Client {
fn get_config(&self) -> &SharedConfig;
fn config(&self) -> &SharedConfig;
fn extra_config(&self) -> &Option<ExtraConfig>;
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let options = self.extra_config();
let timeout = options
.as_ref()
.and_then(|v| v.connect_timeout)
.unwrap_or(10);
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
builder = set_proxy(builder, &proxy)?;
let client = builder
.connect_timeout(Duration::from_secs(timeout))
.build()
.with_context(|| "Failed to build client")?;
Ok(client)
}
fn send_message(&self, content: &str) -> Result<String> {
init_tokio_runtime()?.block_on(async {
if self.get_config().read().dry_run {
let content = self.get_config().read().echo_messages(content);
if self.config().read().dry_run {
let content = self.config().read().echo_messages(content);
return Ok(content);
}
self.send_message_inner(content)
let client = self.build_client()?;
let data = self.config().read().prepare_send_data(content, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to fetch")
})
@ -94,8 +120,8 @@ pub trait Client {
init_tokio_runtime()?.block_on(async {
tokio::select! {
ret = async {
if self.get_config().read().dry_run {
let content = self.get_config().read().echo_messages(content);
if self.config().read().dry_run {
let content = self.config().read().echo_messages(content);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(25)).await;
@ -103,7 +129,9 @@ pub trait Client {
}
return Ok(());
}
self.send_message_streaming_inner(content, handler).await
let client = self.build_client()?;
let data = self.config().read().prepare_send_data(content, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to fetch stream")
@ -120,15 +148,22 @@ pub trait Client {
})
}
async fn send_message_inner(&self, content: &str) -> Result<String>;
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
async fn send_message_streaming_inner(
&self,
content: &str,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()>;
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
pub proxy: Option<String>,
pub connect_timeout: Option<u64>,
}
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
OpenAIClient::init(config.clone())
.or_else(|| LocalAIClient::init(config.clone()))
@ -143,20 +178,20 @@ pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
})
}
pub fn all_clients() -> Vec<&'static str> {
pub fn list_client_types() -> Vec<&'static str> {
vec![
OpenAIClient::name(),
LocalAIClient::name(),
AzureOpenAIClient::name(),
OpenAIClient::NAME,
LocalAIClient::NAME,
AzureOpenAIClient::NAME,
]
}
pub fn create_client_config(client: &str) -> Result<String> {
if client == OpenAIClient::name() {
if client == OpenAIClient::NAME {
OpenAIClient::create_config()
} else if client == LocalAIClient::name() {
} else if client == LocalAIClient::NAME {
LocalAIClient::create_config()
} else if client == AzureOpenAIClient::name() {
} else if client == AzureOpenAIClient::NAME {
AzureOpenAIClient::create_config()
} else {
bail!("Unknown client {}", &client)
@ -176,14 +211,51 @@ pub fn list_models(config: &Config) -> Vec<ModelInfo> {
.collect()
}
pub fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
pub(crate) fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.with_context(|| "Failed to init tokio")
}
pub(crate) fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
pub(crate) fn prompt_input_api_base() -> Result<String> {
Text::new("API Base:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_api_key() -> Result<String> {
Text::new("API Key:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_api_key_optional() -> Result<String> {
Text::new("API Key:").prompt().map_err(prompt_op_err)
}
pub(crate) fn prompt_input_model_name() -> Result<String> {
Text::new("Model Name:")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_input_max_token() -> Result<String> {
Text::new("Max tokens:")
.with_default("4096")
.with_validator(required!("This field is required"))
.prompt()
.map_err(prompt_op_err)
}
pub(crate) fn prompt_op_err<T>(_: T) -> anyhow::Error {
anyhow!("An error happened, try again later.")
}
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
let proxy = if let Some(proxy) = proxy {
if proxy.is_empty() || proxy == "false" || proxy == "-" {
return Ok(builder);
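This file carries the heart of the refactor: `Client` now owns a default `build_client` that applies the shared `ExtraConfig` (proxy and connect timeout), and `send_message`/`send_message_streaming` prepare the `SendData` before delegating, so each provider shrinks to its two `*_inner` transport methods. A stripped-down sketch of the resulting pattern, with simplified stand-in types and a hypothetical `EchoClient` provider:

```rust
use async_trait::async_trait;
use reqwest::Client as ReqwestClient;

// Simplified stand-ins; the real SendData carries Vec<Message>,
// a temperature, and the stream flag.
struct SendData {
    prompt: String,
    stream: bool,
}
#[derive(Default)]
struct ExtraConfig {
    proxy: Option<String>,
    connect_timeout: Option<u64>,
}

#[async_trait]
trait Client {
    fn extra_config(&self) -> &Option<ExtraConfig>;
    // The shared build_client/send_message plumbing lives in default
    // methods (see the diff above); providers implement only this.
    async fn send_message_inner(
        &self,
        client: &ReqwestClient,
        data: SendData,
    ) -> anyhow::Result<String>;
}

// Hypothetical provider: after the refactor, adding one boils down to
// holding its config and writing the transport call.
struct EchoClient {
    extra: Option<ExtraConfig>,
}

#[async_trait]
impl Client for EchoClient {
    fn extra_config(&self) -> &Option<ExtraConfig> {
        &self.extra
    }
    async fn send_message_inner(
        &self,
        _client: &ReqwestClient,
        data: SendData,
    ) -> anyhow::Result<String> {
        Ok(format!("echo(stream={}): {}", data.stream, data.prompt))
    }
}

fn main() -> anyhow::Result<()> {
    let client = EchoClient { extra: Some(ExtraConfig::default()) };
    let _ = client.extra_config();
    let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
    let reply = rt.block_on(client.send_message_inner(
        &ReqwestClient::new(),
        SendData { prompt: "hello".into(), stream: false },
    ))?;
    println!("{reply}");
    Ok(())
}
```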

View File

@ -1,65 +1,66 @@
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use super::{prompt_input_api_key, Client, ClientConfig, ExtraConfig, ModelInfo, SendData};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::env;
use std::time::Duration;
const API_BASE: &str = "https://api.openai.com/v1";
#[derive(Debug)]
pub struct OpenAIClient {
global_config: SharedConfig,
local_config: OpenAIConfig,
config: OpenAIConfig,
model_info: ModelInfo,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub organization_id: Option<String>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connect to openai server
pub connect_timeout: Option<u64>,
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for OpenAIClient {
fn get_config(&self) -> &SharedConfig {
fn config(&self) -> &SharedConfig {
&self.global_config
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
content: &str,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}
impl OpenAIClient {
pub const NAME: &str = "openai";
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != OpenAIClient::name() {
return None;
}
let local_config = {
let config = {
if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
@ -68,16 +69,18 @@ impl OpenAIClient {
};
Some(Box::new(Self {
global_config,
local_config,
config,
model_info,
}))
}
pub fn name() -> &'static str {
"openai"
pub fn name(local_config: &OpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub fn list_models(_local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
[
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16384),
@ -85,90 +88,38 @@ impl OpenAIClient {
("gpt-4-32k", 32768),
]
.into_iter()
.map(|(name, max_tokens)| ModelInfo::new(Self::name(), name, max_tokens, index))
.map(|(name, max_tokens)| ModelInfo::new(client, name, max_tokens, index))
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
let api_key = Text::new("API key:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for api key, try again later."))?;
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
let ans = Confirm::new("Has Organization?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let organization_id = Text::new("Organization ID:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" organization_id: {organization_id}\n"));
}
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
Ok(client_config)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let api_key = if let Some(api_key) = &self.local_config.api_key {
api_key.to_string()
} else if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key.to_string()
} else {
bail!("Miss api_key")
};
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let env_prefix = Self::name(&self.config).to_uppercase();
let messages = self.global_config.read().build_messages(content)?;
let api_key = self.config.api_key.clone();
let api_key = api_key
.or_else(|| env::var(format!("{env_prefix}_API_KEY")).ok())
.ok_or_else(|| anyhow!("Missing api_key"))?;
let mut body = json!({
"model": self.model_info.name,
"messages": messages,
});
let body = openai_build_body(data, self.model_info.name.clone());
if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let client = {
let mut builder = ReqwestClient::builder();
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
};
let api_base = env::var("OPENAI_API_BASE")
let api_base = env::var(format!("{env_prefix}_API_BASE"))
.ok()
.unwrap_or_else(|| API_BASE.to_string());
let url = format!("{api_base}/chat/completions");
let mut builder = client.post(url).bearer_auth(api_key).json(&body);
if let Some(organization_id) = &self.local_config.organization_id {
if let Some(organization_id) = &self.config.organization_id {
builder = builder.header("OpenAI-Organization", organization_id);
}
@ -219,3 +170,26 @@ pub(crate) async fn openai_send_message_streaming(
Ok(())
}
pub(crate) fn openai_build_body(data: SendData, model: String) -> Value {
let SendData {
messages,
temperature,
stream,
} = data;
let mut body = json!({
"model": model,
"messages": messages,
});
if let Some(v) = temperature {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
body
}
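With `openai_build_body` shared by all three providers, the wire format is now uniform. A quick sketch of the JSON it yields (messages reduced to raw JSON here; the real function consumes `SendData` carrying `config::Message` values):

```rust
use serde_json::json;

fn main() {
    // Base body, as built in openai_build_body:
    let mut body = json!({
        "model": "gpt-3.5-turbo",
        "messages": [{ "role": "user", "content": "hello" }],
    });
    // temperature and stream are inserted only when requested:
    if let Some(m) = body.as_object_mut() {
        m.insert("temperature".into(), json!(0.7));
        m.insert("stream".into(), json!(true));
    }
    assert_eq!(body["stream"], json!(true));
    println!("{body}");
}
```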

View File

@ -2,12 +2,15 @@ mod message;
mod role;
mod session;
use self::message::Message;
pub use self::message::Message;
use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};
use crate::client::openai::{OpenAIClient, OpenAIConfig};
use crate::client::{all_clients, create_client_config, list_models, ClientConfig, ModelInfo};
use crate::client::{
create_client_config, list_client_types, list_models, prompt_op_err, ClientConfig, ExtraConfig,
ModelInfo, SendData,
};
use crate::config::message::num_tokens_from_messages;
use crate::render::RenderOptions;
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now};
@ -94,7 +97,7 @@ impl Default for Config {
Self {
model: None,
default_temperature: None,
save: false,
save: true,
highlight: true,
dry_run: false,
light_theme: false,
@ -136,19 +139,17 @@ impl Config {
config.compat_old_config(&config_path)?;
}
if let Some(name) = config.model.clone() {
config.set_model(&name)?;
}
if let Some(wrap) = config.wrap.clone() {
config.set_wrap(&wrap)?;
}
config.temperature = config.default_temperature;
config.set_model_info()?;
config.merge_env_vars();
config.load_roles()?;
config.ensure_sessions_dir()?;
config.check_term_theme()?;
config.detect_theme()?;
Ok(config)
}
@ -327,7 +328,7 @@ impl Config {
model_info = Some(model.clone());
}
match model_info {
None => bail!("Invalid model"),
None => bail!("Unknown model '{}'", value),
Some(model_info) => {
if let Some(session) = self.session.as_mut() {
session.set_model(&model_info.stringify())?;
@ -555,6 +556,15 @@ impl Config {
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
}
pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
let messages = self.build_messages(content)?;
Ok(SendData {
messages,
temperature: self.get_temperature(),
stream,
})
}
pub fn maybe_print_send_tokens(&self, input: &str) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
@ -596,6 +606,22 @@ impl Config {
Ok(())
}
fn set_model_info(&mut self) -> Result<()> {
let model = match &self.model {
Some(v) => v.clone(),
None => {
let models = self::list_models(self);
if models.is_empty() {
bail!("No available model");
}
models[0].stringify()
}
};
self.set_model(&model)?;
Ok(())
}
fn merge_env_vars(&mut self) {
if let Ok(value) = env::var("NO_COLOR") {
let mut no_color = false;
@ -616,7 +642,7 @@ impl Config {
Ok(())
}
fn check_term_theme(&mut self) -> Result<()> {
fn detect_theme(&mut self) -> Result<()> {
if self.light_theme {
return Ok(());
}
@ -640,7 +666,7 @@ impl Config {
if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) {
if model_name.starts_with("gpt") {
self.model = Some(format!("{}:{}", OpenAIClient::name(), model_name));
self.model = Some(format!("{}:{}", OpenAIClient::NAME, model_name));
}
}
@ -653,13 +679,17 @@ impl Config {
client_config.organization_id = Some(organization_id.to_string())
}
let mut extra_config = ExtraConfig::default();
if let Some(proxy) = value.get("proxy").and_then(|v| v.as_str()) {
client_config.proxy = Some(proxy.to_string())
extra_config.proxy = Some(proxy.to_string())
}
if let Some(connect_timeout) = value.get("connect_timeout").and_then(|v| v.as_i64()) {
client_config.connect_timeout = Some(connect_timeout as _)
extra_config.connect_timeout = Some(connect_timeout as _)
}
client_config.extra = Some(extra_config);
}
Ok(())
}
@ -690,27 +720,19 @@ fn create_config_file(config_path: &Path) -> Result<()> {
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
.map_err(prompt_op_err)?;
if !ans {
exit(0);
}
let client = Select::new("Select AI?", all_clients())
let client = Select::new("AI Platform:", list_client_types())
.prompt()
.map_err(|_| anyhow!("An error happened when selecting platform, try again later."))?;
.map_err(prompt_op_err)?;
let mut raw_config = create_client_config(client)?;
raw_config.push_str(&format!("model: {client}\n"));
let ans = Confirm::new("Save chat messages")
.with_default(true)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
raw_config.push_str("save: true\n");
}
ensure_parent_exists(config_path)?;
std::fs::write(config_path, raw_config).with_context(|| "Failed to write to config file")?;
#[cfg(unix)]
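Finally, the first-launch questionnaire is rebuilt on the shared `prompt_input_*` helpers, the platform picker is reworded to "AI Platform:", and the "Save chat messages" question disappears because `save` now defaults to true. A condensed sketch of the new flow (prompt texts taken from the diff; the YAML indentation and api_key skip-when-env-set logic are approximated):

```rust
use anyhow::{anyhow, Result};
use inquire::{required, Confirm, Select, Text};

fn prompt_op_err<T>(_: T) -> anyhow::Error {
    anyhow!("An error happened, try again later.")
}

// Condensed sketch of create_config_file after this change.
fn bootstrap_config() -> Result<String> {
    let create = Confirm::new("No config file, create a new one?")
        .with_default(true)
        .prompt()
        .map_err(prompt_op_err)?;
    if !create {
        return Ok(String::new());
    }
    let client = Select::new("AI Platform:", vec!["openai", "localai", "azure-openai"])
        .prompt()
        .map_err(prompt_op_err)?;
    let api_key = Text::new("API Key:")
        .with_validator(required!("This field is required"))
        .prompt()
        .map_err(prompt_op_err)?;
    Ok(format!(
        "clients:\n - type: {client}\n   api_key: {api_key}\nmodel: {client}\n"
    ))
}

fn main() -> Result<()> {
    println!("{}", bootstrap_config()?);
    Ok(())
}
```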