mirror of
https://github.com/sigoden/aichat
synced 2024-11-10 07:10:36 +00:00
feat: improve client (#189)
This commit is contained in:
parent
9b614600c6
commit
a137483b03
35
README.md
35
README.md
@ -44,10 +44,8 @@ On first launch, aichat will guide you through the configuration.
|
||||
|
||||
```
|
||||
> No config file, create a new one? Yes
|
||||
> Select AI? openai
|
||||
> API key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
> Has Organization? No
|
||||
> Use proxy? No
|
||||
> AI Platform: openai
|
||||
> API Key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
> Save chat messages Yes
|
||||
```
|
||||
|
||||
@ -64,24 +62,35 @@ wrap_code: false # Whether wrap code block
|
||||
auto_copy: false # Automatically copy the last output to the clipboard
|
||||
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi
|
||||
|
||||
clients: # Setup AIs
|
||||
clients:
|
||||
# All clients have the following configuration:
|
||||
# ```
|
||||
# - type: xxxx
|
||||
# name: nova # Only use it to distinguish clients with the same client type. Optional
|
||||
# extra:
|
||||
# proxy: socks5://127.0.0.1:1080 # Specify https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also works.
|
||||
# connect_timeout: 10 # Set a timeout in seconds for connect to server
|
||||
# ```
|
||||
|
||||
# See https://platform.openai.com/docs/quickstart
|
||||
- type: openai # OpenAI configuration
|
||||
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
|
||||
- type: openai
|
||||
api_key: sk-xxx
|
||||
organization_id: org-xxx # Organization ID. Optional
|
||||
|
||||
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
|
||||
- type: azure-openai # Azure openai configuration
|
||||
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
|
||||
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
|
||||
- type: azure-openai
|
||||
api_base: https://RESOURCE.openai.azure.com
|
||||
api_key: xxx
|
||||
models: # Support models
|
||||
- name: MyGPT4 # Model deployment name
|
||||
max_tokens: 8192
|
||||
|
||||
|
||||
# See https://github.com/go-skynet/LocalAI
|
||||
- type: localai # LocalAI configuration
|
||||
url: http://localhost:8080/v1/chat/completions # LocalAI api server
|
||||
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
|
||||
- type: localai
|
||||
api_base: http://localhost:8080/v1
|
||||
api_key: xxx
|
||||
chat_endpoint: /chat/completions # Optional
|
||||
models: # Support models
|
||||
- name: gpt4all-j
|
||||
max_tokens: 8192
|
||||
|
@ -8,30 +8,35 @@ wrap_code: false # Whether wrap code block
|
||||
auto_copy: false # Automatically copy the last output to the clipboard
|
||||
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi
|
||||
|
||||
clients: # Setup AIs
|
||||
clients:
|
||||
# All clients have the following configuration:
|
||||
# ```
|
||||
# - type: xxxx
|
||||
# name: nova # Only use it to distinguish clients with the same client type. Optional
|
||||
# extra:
|
||||
# proxy: socks5://127.0.0.1:1080 # Specify https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also works.
|
||||
# connect_timeout: 10 # Set a timeout in seconds for connect to server
|
||||
# ```
|
||||
|
||||
# See https://platform.openai.com/docs/quickstart
|
||||
- type: openai # OpenAI configuration
|
||||
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
|
||||
- type: openai
|
||||
api_key: sk-xxx
|
||||
organization_id: org-xxx # Organization ID. Optional
|
||||
proxy: socks5://127.0.0.1:1080
|
||||
connect_timeout: 10
|
||||
|
||||
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
|
||||
- type: azure-openai # Azure openai configuration
|
||||
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
|
||||
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
|
||||
- type: azure-openai
|
||||
api_base: https://RESOURCE.openai.azure.com
|
||||
api_key: xxx
|
||||
models: # Support models
|
||||
- name: MyGPT4 # Model deployment name
|
||||
max_tokens: 8192
|
||||
proxy: socks5://127.0.0.1:1080 # Set proxy server. Optional
|
||||
connect_timeout: 10 # Set a timeout in seconds for connect to gpt. Optional
|
||||
|
||||
|
||||
# See https://github.com/go-skynet/LocalAI
|
||||
- type: localai # LocalAI configuration
|
||||
url: http://localhost:8080/v1/chat/completions # LocalAI api server
|
||||
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
|
||||
- type: localai
|
||||
api_base: http://localhost:8080/v1
|
||||
api_key: xxx
|
||||
chat_endpoint: /chat/completions # Optional
|
||||
models: # Support models
|
||||
- name: gpt4all-j
|
||||
max_tokens: 8192
|
||||
proxy: socks5://127.0.0.1:1080
|
||||
connect_timeout: 10
|
@ -1,34 +1,33 @@
|
||||
use super::openai::{openai_send_message, openai_send_message_streaming};
|
||||
use super::{set_proxy, Client, ClientConfig, ModelInfo};
|
||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
||||
use super::{
|
||||
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
|
||||
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
|
||||
};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use inquire::{Confirm, Text};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
use std::env;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AzureOpenAIClient {
|
||||
global_config: SharedConfig,
|
||||
local_config: AzureOpenAIConfig,
|
||||
config: AzureOpenAIConfig,
|
||||
model_info: ModelInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AzureOpenAIConfig {
|
||||
pub name: Option<String>,
|
||||
pub api_base: String,
|
||||
pub api_key: Option<String>,
|
||||
pub models: Vec<AzureOpenAIModel>,
|
||||
pub proxy: Option<String>,
|
||||
/// Set a timeout in seconds for connect to server
|
||||
pub connect_timeout: Option<u64>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@ -39,32 +38,36 @@ pub struct AzureOpenAIModel {
|
||||
|
||||
#[async_trait]
|
||||
impl Client for AzureOpenAIClient {
|
||||
fn get_config(&self) -> &SharedConfig {
|
||||
fn config(&self) -> &SharedConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
let builder = self.request_builder(content, false)?;
|
||||
fn extra_config(&self) -> &Option<ExtraConfig> {
|
||||
&self.config.extra
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(content, true)?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl AzureOpenAIClient {
|
||||
pub const NAME: &str = "azure-openai";
|
||||
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
if model_info.client != AzureOpenAIClient::name() {
|
||||
return None;
|
||||
}
|
||||
let local_config = {
|
||||
let config = {
|
||||
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
@ -73,59 +76,37 @@ impl AzureOpenAIClient {
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
local_config,
|
||||
config,
|
||||
model_info,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"azure-openai"
|
||||
pub fn name(local_config: &AzureOpenAIConfig) -> &str {
|
||||
local_config.name.as_deref().unwrap_or(Self::NAME)
|
||||
}
|
||||
|
||||
pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
let client = Self::name(local_config);
|
||||
|
||||
local_config
|
||||
.models
|
||||
.iter()
|
||||
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
|
||||
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
||||
|
||||
let api_base = Text::new("api_base:")
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when asking for api base, try again later."))?;
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
|
||||
|
||||
let api_base = prompt_input_api_base()?;
|
||||
client_config.push_str(&format!(" api_base: {api_base}\n"));
|
||||
|
||||
if env::var("AZURE_OPENAI_KEY").is_err() {
|
||||
let api_key = Text::new("API key:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for api key, try again later.")
|
||||
})?;
|
||||
|
||||
let api_key = prompt_input_api_key()?;
|
||||
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
||||
}
|
||||
|
||||
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for model name, try again later.")
|
||||
})?;
|
||||
let model_name = prompt_input_model_name()?;
|
||||
|
||||
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for max tokens, try again later.")
|
||||
})?;
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
||||
}
|
||||
let max_tokens = prompt_input_max_token()?;
|
||||
|
||||
client_config.push_str(&format!(
|
||||
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
|
||||
@ -134,50 +115,26 @@ impl AzureOpenAIClient {
|
||||
Ok(client_config)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let messages = self.global_config.read().build_messages(content)?;
|
||||
|
||||
let mut body = json!({
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
if let Some(v) = self.global_config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
let client = {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
builder = set_proxy(builder, &self.local_config.proxy)?;
|
||||
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
||||
builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.config.api_key.clone();
|
||||
let api_key = api_key
|
||||
.or_else(|| {
|
||||
let env_prefix = match &self.config.name {
|
||||
None => "AZURE".into(),
|
||||
Some(v) => v.to_uppercase(),
|
||||
};
|
||||
let mut api_base = self.local_config.api_base.clone();
|
||||
if !api_base.ends_with('/') {
|
||||
api_base = format!("{api_base}/");
|
||||
}
|
||||
env::var(format!("{env_prefix}_OPENAI_KEY")).ok()
|
||||
})
|
||||
.ok_or_else(|| anyhow!("Miss api_key"))?;
|
||||
|
||||
let body = openai_build_body(data, self.model_info.name.clone());
|
||||
|
||||
let url = format!(
|
||||
"{api_base}openai/deployments/{}/chat/completions?api-version=2023-05-15",
|
||||
self.model_info.name
|
||||
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
|
||||
self.config.api_base, self.model_info.name
|
||||
);
|
||||
|
||||
let mut builder = client.post(url);
|
||||
|
||||
if let Some(api_key) = &self.local_config.api_key {
|
||||
builder = builder.header("api-key", api_key)
|
||||
} else if let Ok(api_key) = env::var("AZURE_OPENAI_KEY") {
|
||||
builder = builder.header("api-key", api_key)
|
||||
}
|
||||
builder = builder.json(&body);
|
||||
let builder = client.post(url).header("api-key", api_key).json(&body);
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
|
@ -1,33 +1,33 @@
|
||||
use super::openai::{openai_send_message, openai_send_message_streaming};
|
||||
use super::{set_proxy, Client, ClientConfig, ModelInfo};
|
||||
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
|
||||
use super::{
|
||||
prompt_input_api_base, prompt_input_api_key_optional, prompt_input_max_token,
|
||||
prompt_input_model_name, Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
|
||||
};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use inquire::{Confirm, Text};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LocalAIClient {
|
||||
global_config: SharedConfig,
|
||||
local_config: LocalAIConfig,
|
||||
config: LocalAIConfig,
|
||||
model_info: ModelInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct LocalAIConfig {
|
||||
pub url: String,
|
||||
pub name: Option<String>,
|
||||
pub api_base: String,
|
||||
pub api_key: Option<String>,
|
||||
pub chat_endpoint: Option<String>,
|
||||
pub models: Vec<LocalAIModel>,
|
||||
pub proxy: Option<String>,
|
||||
/// Set a timeout in seconds for connect to server
|
||||
pub connect_timeout: Option<u64>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@ -38,32 +38,36 @@ pub struct LocalAIModel {
|
||||
|
||||
#[async_trait]
|
||||
impl Client for LocalAIClient {
|
||||
fn get_config(&self) -> &SharedConfig {
|
||||
fn config(&self) -> &SharedConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
let builder = self.request_builder(content, false)?;
|
||||
fn extra_config(&self) -> &Option<ExtraConfig> {
|
||||
&self.config.extra
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(content, true)?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalAIClient {
|
||||
pub const NAME: &str = "localai";
|
||||
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
if model_info.client != LocalAIClient::name() {
|
||||
return None;
|
||||
}
|
||||
let local_config = {
|
||||
let config = {
|
||||
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
@ -72,64 +76,37 @@ impl LocalAIClient {
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
local_config,
|
||||
config,
|
||||
model_info,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"localai"
|
||||
pub fn name(local_config: &LocalAIConfig) -> &str {
|
||||
local_config.name.as_deref().unwrap_or(Self::NAME)
|
||||
}
|
||||
|
||||
pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
let client = Self::name(local_config);
|
||||
|
||||
local_config
|
||||
.models
|
||||
.iter()
|
||||
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
|
||||
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
|
||||
|
||||
let url = Text::new("URL:")
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when asking for url, try again later."))?;
|
||||
|
||||
client_config.push_str(&format!(" url: {url}\n"));
|
||||
|
||||
let ans = Confirm::new("Use auth?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let api_key = Text::new("API key:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for api key, try again later.")
|
||||
})?;
|
||||
let api_base = prompt_input_api_base()?;
|
||||
client_config.push_str(&format!(" api_base: {api_base}\n"));
|
||||
|
||||
let api_key = prompt_input_api_key_optional()?;
|
||||
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
||||
}
|
||||
|
||||
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for model name, try again later.")
|
||||
})?;
|
||||
let model_name = prompt_input_model_name()?;
|
||||
|
||||
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for max tokens, try again later.")
|
||||
})?;
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
||||
}
|
||||
let max_tokens = prompt_input_max_token()?;
|
||||
|
||||
client_config.push_str(&format!(
|
||||
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
|
||||
@ -138,41 +115,27 @@ impl LocalAIClient {
|
||||
Ok(client_config)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let messages = self.global_config.read().build_messages(content)?;
|
||||
|
||||
let mut body = json!({
|
||||
"model": self.model_info.name,
|
||||
"messages": messages,
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.config.api_key.clone();
|
||||
let api_key = api_key.or_else(|| {
|
||||
let env_prefix = Self::name(&self.config).to_uppercase();
|
||||
env::var(format!("{env_prefix}_API_KEY")).ok()
|
||||
});
|
||||
|
||||
if let Some(v) = self.global_config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
let body = openai_build_body(data, self.model_info.name.clone());
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
let chat_endpoint = self
|
||||
.config
|
||||
.chat_endpoint
|
||||
.as_deref()
|
||||
.unwrap_or("/chat/completions");
|
||||
|
||||
let client = {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
builder = set_proxy(builder, &self.local_config.proxy)?;
|
||||
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
||||
builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?
|
||||
};
|
||||
let url = format!("{}{chat_endpoint}", self.config.api_base);
|
||||
|
||||
let mut builder = client.post(&self.local_config.url);
|
||||
if let Some(api_key) = &self.local_config.api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
} else if let Ok(api_key) = env::var("LOCALAI_API_KEY") {
|
||||
let mut builder = client.post(url).json(&body);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
}
|
||||
builder = builder.json(&body);
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
|
@ -8,20 +8,21 @@ use self::{
|
||||
openai::{OpenAIClient, OpenAIConfig},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{ClientBuilder, Proxy};
|
||||
use serde::Deserialize;
|
||||
use std::{env, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::{
|
||||
client::localai::LocalAIClient,
|
||||
config::{Config, SharedConfig},
|
||||
config::{Config, Message, SharedConfig},
|
||||
repl::{ReplyStreamHandler, SharedAbortSignal},
|
||||
utils::tokenize,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use inquire::{required, Text};
|
||||
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
|
||||
use serde::Deserialize;
|
||||
use std::{env, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
@ -32,7 +33,6 @@ pub enum ClientConfig {
|
||||
#[serde(rename = "azure-openai")]
|
||||
AzureOpenAI(AzureOpenAIConfig),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelInfo {
|
||||
pub client: String,
|
||||
@ -61,17 +61,43 @@ impl ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SendData {
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f64>,
|
||||
pub stream: bool,
|
||||
}
|
||||
#[async_trait]
|
||||
pub trait Client {
|
||||
fn get_config(&self) -> &SharedConfig;
|
||||
fn config(&self) -> &SharedConfig;
|
||||
|
||||
fn extra_config(&self) -> &Option<ExtraConfig>;
|
||||
|
||||
fn build_client(&self) -> Result<ReqwestClient> {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
let options = self.extra_config();
|
||||
let timeout = options
|
||||
.as_ref()
|
||||
.and_then(|v| v.connect_timeout)
|
||||
.unwrap_or(10);
|
||||
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
||||
builder = set_proxy(builder, &proxy)?;
|
||||
let client = builder
|
||||
.connect_timeout(Duration::from_secs(timeout))
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn send_message(&self, content: &str) -> Result<String> {
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
if self.get_config().read().dry_run {
|
||||
let content = self.get_config().read().echo_messages(content);
|
||||
if self.config().read().dry_run {
|
||||
let content = self.config().read().echo_messages(content);
|
||||
return Ok(content);
|
||||
}
|
||||
self.send_message_inner(content)
|
||||
let client = self.build_client()?;
|
||||
let data = self.config().read().prepare_send_data(content, false)?;
|
||||
self.send_message_inner(&client, data)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch")
|
||||
})
|
||||
@ -94,8 +120,8 @@ pub trait Client {
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
if self.get_config().read().dry_run {
|
||||
let content = self.get_config().read().echo_messages(content);
|
||||
if self.config().read().dry_run {
|
||||
let content = self.config().read().echo_messages(content);
|
||||
let tokens = tokenize(&content);
|
||||
for token in tokens {
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
@ -103,7 +129,9 @@ pub trait Client {
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
self.send_message_streaming_inner(content, handler).await
|
||||
let client = self.build_client()?;
|
||||
let data = self.config().read().prepare_send_data(content, true)?;
|
||||
self.send_message_streaming_inner(&client, handler, data).await
|
||||
} => {
|
||||
handler.done()?;
|
||||
ret.with_context(|| "Failed to fetch stream")
|
||||
@ -120,15 +148,22 @@ pub trait Client {
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String>;
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct ExtraConfig {
|
||||
pub proxy: Option<String>,
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
OpenAIClient::init(config.clone())
|
||||
.or_else(|| LocalAIClient::init(config.clone()))
|
||||
@ -143,20 +178,20 @@ pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn all_clients() -> Vec<&'static str> {
|
||||
pub fn list_client_types() -> Vec<&'static str> {
|
||||
vec![
|
||||
OpenAIClient::name(),
|
||||
LocalAIClient::name(),
|
||||
AzureOpenAIClient::name(),
|
||||
OpenAIClient::NAME,
|
||||
LocalAIClient::NAME,
|
||||
AzureOpenAIClient::NAME,
|
||||
]
|
||||
}
|
||||
|
||||
pub fn create_client_config(client: &str) -> Result<String> {
|
||||
if client == OpenAIClient::name() {
|
||||
if client == OpenAIClient::NAME {
|
||||
OpenAIClient::create_config()
|
||||
} else if client == LocalAIClient::name() {
|
||||
} else if client == LocalAIClient::NAME {
|
||||
LocalAIClient::create_config()
|
||||
} else if client == AzureOpenAIClient::name() {
|
||||
} else if client == AzureOpenAIClient::NAME {
|
||||
AzureOpenAIClient::create_config()
|
||||
} else {
|
||||
bail!("Unknown client {}", &client)
|
||||
@ -176,14 +211,51 @@ pub fn list_models(config: &Config) -> Vec<ModelInfo> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
|
||||
pub(crate) fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.with_context(|| "Failed to init tokio")
|
||||
}
|
||||
|
||||
pub(crate) fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
|
||||
pub(crate) fn prompt_input_api_base() -> Result<String> {
|
||||
Text::new("API Base:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_api_key() -> Result<String> {
|
||||
Text::new("API Key:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_api_key_optional() -> Result<String> {
|
||||
Text::new("API Key:").prompt().map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_model_name() -> Result<String> {
|
||||
Text::new("Model Name:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_max_token() -> Result<String> {
|
||||
Text::new("Max tokens:")
|
||||
.with_default("4096")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_op_err<T>(_: T) -> anyhow::Error {
|
||||
anyhow!("An error happened, try again later.")
|
||||
}
|
||||
|
||||
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
|
||||
let proxy = if let Some(proxy) = proxy {
|
||||
if proxy.is_empty() || proxy == "false" || proxy == "-" {
|
||||
return Ok(builder);
|
||||
|
@ -1,65 +1,66 @@
|
||||
use super::{set_proxy, Client, ClientConfig, ModelInfo};
|
||||
use super::{prompt_input_api_key, Client, ClientConfig, ExtraConfig, ModelInfo, SendData};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures_util::StreamExt;
|
||||
use inquire::{Confirm, Text};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
const API_BASE: &str = "https://api.openai.com/v1";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIClient {
|
||||
global_config: SharedConfig,
|
||||
local_config: OpenAIConfig,
|
||||
config: OpenAIConfig,
|
||||
model_info: ModelInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct OpenAIConfig {
|
||||
pub name: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub organization_id: Option<String>,
|
||||
pub proxy: Option<String>,
|
||||
/// Set a timeout in seconds for connect to openai server
|
||||
pub connect_timeout: Option<u64>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for OpenAIClient {
|
||||
fn get_config(&self) -> &SharedConfig {
|
||||
fn config(&self) -> &SharedConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
let builder = self.request_builder(content, false)?;
|
||||
fn extra_config(&self) -> &Option<ExtraConfig> {
|
||||
&self.config.extra
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(content, true)?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
pub const NAME: &str = "openai";
|
||||
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
if model_info.client != OpenAIClient::name() {
|
||||
return None;
|
||||
}
|
||||
let local_config = {
|
||||
let config = {
|
||||
if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
@ -68,16 +69,18 @@ impl OpenAIClient {
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
local_config,
|
||||
config,
|
||||
model_info,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"openai"
|
||||
pub fn name(local_config: &OpenAIConfig) -> &str {
|
||||
local_config.name.as_deref().unwrap_or(Self::NAME)
|
||||
}
|
||||
|
||||
pub fn list_models(_local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
||||
let client = Self::name(local_config);
|
||||
|
||||
[
|
||||
("gpt-3.5-turbo", 4096),
|
||||
("gpt-3.5-turbo-16k", 16384),
|
||||
@ -85,90 +88,38 @@ impl OpenAIClient {
|
||||
("gpt-4-32k", 32768),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|(name, max_tokens)| ModelInfo::new(Self::name(), name, max_tokens, index))
|
||||
.map(|(name, max_tokens)| ModelInfo::new(client, name, max_tokens, index))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
||||
|
||||
let api_key = Text::new("API key:")
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when asking for api key, try again later."))?;
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
|
||||
|
||||
let api_key = prompt_input_api_key()?;
|
||||
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
||||
|
||||
let ans = Confirm::new("Has Organization?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let organization_id = Text::new("Organization ID:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" organization_id: {organization_id}\n"));
|
||||
}
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
||||
}
|
||||
|
||||
Ok(client_config)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let api_key = if let Some(api_key) = &self.local_config.api_key {
|
||||
api_key.to_string()
|
||||
} else if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key.to_string()
|
||||
} else {
|
||||
bail!("Miss api_key")
|
||||
};
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let env_prefix = Self::name(&self.config).to_uppercase();
|
||||
|
||||
let messages = self.global_config.read().build_messages(content)?;
|
||||
let api_key = self.config.api_key.clone();
|
||||
let api_key = api_key
|
||||
.or_else(|| env::var(format!("{env_prefix}_API_KEY")).ok())
|
||||
.ok_or_else(|| anyhow!("Miss api_key"))?;
|
||||
|
||||
let mut body = json!({
|
||||
"model": self.model_info.name,
|
||||
"messages": messages,
|
||||
});
|
||||
let body = openai_build_body(data, self.model_info.name.clone());
|
||||
|
||||
if let Some(v) = self.global_config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
let client = {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
builder = set_proxy(builder, &self.local_config.proxy)?;
|
||||
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
||||
builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?
|
||||
};
|
||||
|
||||
let api_base = env::var("OPENAI_API_BASE")
|
||||
let api_base = env::var(format!("{env_prefix}_API_BASE"))
|
||||
.ok()
|
||||
.unwrap_or_else(|| API_BASE.to_string());
|
||||
|
||||
let url = format!("{api_base}/chat/completions");
|
||||
|
||||
let mut builder = client.post(url).bearer_auth(api_key).json(&body);
|
||||
|
||||
if let Some(organization_id) = &self.local_config.organization_id {
|
||||
if let Some(organization_id) = &self.config.organization_id {
|
||||
builder = builder.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
@ -219,3 +170,26 @@ pub(crate) async fn openai_send_message_streaming(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn openai_build_body(data: SendData, model: String) -> Value {
|
||||
let SendData {
|
||||
messages,
|
||||
temperature,
|
||||
stream,
|
||||
} = data;
|
||||
let mut body = json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
if let Some(v) = temperature {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
body
|
||||
}
|
||||
|
@ -2,12 +2,15 @@ mod message;
|
||||
mod role;
|
||||
mod session;
|
||||
|
||||
use self::message::Message;
|
||||
pub use self::message::Message;
|
||||
use self::role::Role;
|
||||
use self::session::{Session, TEMP_SESSION_NAME};
|
||||
|
||||
use crate::client::openai::{OpenAIClient, OpenAIConfig};
|
||||
use crate::client::{all_clients, create_client_config, list_models, ClientConfig, ModelInfo};
|
||||
use crate::client::{
|
||||
create_client_config, list_client_types, list_models, prompt_op_err, ClientConfig, ExtraConfig,
|
||||
ModelInfo, SendData,
|
||||
};
|
||||
use crate::config::message::num_tokens_from_messages;
|
||||
use crate::render::RenderOptions;
|
||||
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now};
|
||||
@ -94,7 +97,7 @@ impl Default for Config {
|
||||
Self {
|
||||
model: None,
|
||||
default_temperature: None,
|
||||
save: false,
|
||||
save: true,
|
||||
highlight: true,
|
||||
dry_run: false,
|
||||
light_theme: false,
|
||||
@ -136,19 +139,17 @@ impl Config {
|
||||
config.compat_old_config(&config_path)?;
|
||||
}
|
||||
|
||||
if let Some(name) = config.model.clone() {
|
||||
config.set_model(&name)?;
|
||||
}
|
||||
if let Some(wrap) = config.wrap.clone() {
|
||||
config.set_wrap(&wrap)?;
|
||||
}
|
||||
|
||||
config.temperature = config.default_temperature;
|
||||
|
||||
config.set_model_info()?;
|
||||
config.merge_env_vars();
|
||||
config.load_roles()?;
|
||||
config.ensure_sessions_dir()?;
|
||||
config.check_term_theme()?;
|
||||
config.detect_theme()?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
@ -327,7 +328,7 @@ impl Config {
|
||||
model_info = Some(model.clone());
|
||||
}
|
||||
match model_info {
|
||||
None => bail!("Invalid model"),
|
||||
None => bail!("Unknown model '{}'", value),
|
||||
Some(model_info) => {
|
||||
if let Some(session) = self.session.as_mut() {
|
||||
session.set_model(&model_info.stringify())?;
|
||||
@ -555,6 +556,15 @@ impl Config {
|
||||
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
|
||||
}
|
||||
|
||||
pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
|
||||
let messages = self.build_messages(content)?;
|
||||
Ok(SendData {
|
||||
messages,
|
||||
temperature: self.get_temperature(),
|
||||
stream,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn maybe_print_send_tokens(&self, input: &str) {
|
||||
if self.dry_run {
|
||||
if let Ok(messages) = self.build_messages(input) {
|
||||
@ -596,6 +606,22 @@ impl Config {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_model_info(&mut self) -> Result<()> {
|
||||
let model = match &self.model {
|
||||
Some(v) => v.clone(),
|
||||
None => {
|
||||
let models = self::list_models(self);
|
||||
if models.is_empty() {
|
||||
bail!("No available model");
|
||||
}
|
||||
|
||||
models[0].stringify()
|
||||
}
|
||||
};
|
||||
self.set_model(&model)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn merge_env_vars(&mut self) {
|
||||
if let Ok(value) = env::var("NO_COLOR") {
|
||||
let mut no_color = false;
|
||||
@ -616,7 +642,7 @@ impl Config {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_term_theme(&mut self) -> Result<()> {
|
||||
fn detect_theme(&mut self) -> Result<()> {
|
||||
if self.light_theme {
|
||||
return Ok(());
|
||||
}
|
||||
@ -640,7 +666,7 @@ impl Config {
|
||||
|
||||
if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) {
|
||||
if model_name.starts_with("gpt") {
|
||||
self.model = Some(format!("{}:{}", OpenAIClient::name(), model_name));
|
||||
self.model = Some(format!("{}:{}", OpenAIClient::NAME, model_name));
|
||||
}
|
||||
}
|
||||
|
||||
@ -653,13 +679,17 @@ impl Config {
|
||||
client_config.organization_id = Some(organization_id.to_string())
|
||||
}
|
||||
|
||||
let mut extra_config = ExtraConfig::default();
|
||||
|
||||
if let Some(proxy) = value.get("proxy").and_then(|v| v.as_str()) {
|
||||
client_config.proxy = Some(proxy.to_string())
|
||||
extra_config.proxy = Some(proxy.to_string())
|
||||
}
|
||||
|
||||
if let Some(connect_timeout) = value.get("connect_timeout").and_then(|v| v.as_i64()) {
|
||||
client_config.connect_timeout = Some(connect_timeout as _)
|
||||
extra_config.connect_timeout = Some(connect_timeout as _)
|
||||
}
|
||||
|
||||
client_config.extra = Some(extra_config);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -690,27 +720,19 @@ fn create_config_file(config_path: &Path) -> Result<()> {
|
||||
let ans = Confirm::new("No config file, create a new one?")
|
||||
.with_default(true)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
.map_err(prompt_op_err)?;
|
||||
if !ans {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
let client = Select::new("Select AI?", all_clients())
|
||||
let client = Select::new("AI Platform:", list_client_types())
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when selecting platform, try again later."))?;
|
||||
.map_err(prompt_op_err)?;
|
||||
|
||||
let mut raw_config = create_client_config(client)?;
|
||||
|
||||
raw_config.push_str(&format!("model: {client}\n"));
|
||||
|
||||
let ans = Confirm::new("Save chat messages")
|
||||
.with_default(true)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
raw_config.push_str("save: true\n");
|
||||
}
|
||||
ensure_parent_exists(config_path)?;
|
||||
std::fs::write(config_path, raw_config).with_context(|| "Failed to write to config file")?;
|
||||
#[cfg(unix)]
|
||||
|
Loading…
Reference in New Issue
Block a user