use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
    prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
    Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use std::env;
|
|
|
|
|
|
|
|
|
|
|
|
#[allow(clippy::module_name_repetitions)]
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct AzureOpenAIClient {
|
|
|
|
pub struct AzureOpenAIClient {
|
|
|
|
global_config: SharedConfig,
|
|
|
|
global_config: SharedConfig,
|
|
|
|
local_config: AzureOpenAIConfig,
|
|
|
|
config: AzureOpenAIConfig,
|
|
|
|
model_info: ModelInfo,
|
|
|
|
model_info: ModelInfo,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
|
pub struct AzureOpenAIConfig {
|
|
|
|
pub struct AzureOpenAIConfig {
|
|
|
|
|
|
|
|
pub name: Option<String>,
|
|
|
|
pub api_base: String,
|
|
|
|
pub api_base: String,
|
|
|
|
pub api_key: Option<String>,
|
|
|
|
pub api_key: Option<String>,
|
|
|
|
pub models: Vec<AzureOpenAIModel>,
|
|
|
|
pub models: Vec<AzureOpenAIModel>,
|
|
|
|
pub proxy: Option<String>,
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
/// Set a timeout in seconds for connect to server
|
|
|
|
|
|
|
|
pub connect_timeout: Option<u64>,
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
@ -39,32 +38,36 @@ pub struct AzureOpenAIModel {
|
|
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
#[async_trait]
|
|
|
|
impl Client for AzureOpenAIClient {
|
|
|
|
impl Client for AzureOpenAIClient {
|
|
|
|
fn get_config(&self) -> &SharedConfig {
|
|
|
|
fn config(&self) -> &SharedConfig {
|
|
|
|
&self.global_config
|
|
|
|
&self.global_config
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
|
|
|
fn extra_config(&self) -> &Option<ExtraConfig> {
|
|
|
|
let builder = self.request_builder(content, false)?;
|
|
|
|
&self.config.extra
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
|
|
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
openai_send_message(builder).await
|
|
|
|
openai_send_message(builder).await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming_inner(
|
|
|
|
async fn send_message_streaming_inner(
|
|
|
|
&self,
|
|
|
|
&self,
|
|
|
|
content: &str,
|
|
|
|
client: &ReqwestClient,
|
|
|
|
handler: &mut ReplyStreamHandler,
|
|
|
|
handler: &mut ReplyStreamHandler,
|
|
|
|
|
|
|
|
data: SendData,
|
|
|
|
) -> Result<()> {
|
|
|
|
) -> Result<()> {
|
|
|
|
let builder = self.request_builder(content, true)?;
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
openai_send_message_streaming(builder, handler).await
|
|
|
|
openai_send_message_streaming(builder, handler).await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl AzureOpenAIClient {
|
|
|
|
impl AzureOpenAIClient {
|
|
|
|
|
|
|
|
pub const NAME: &str = "azure-openai";
|
|
|
|
|
|
|
|
|
|
|
|
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
|
|
|
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
|
|
|
let model_info = global_config.read().model_info.clone();
|
|
|
|
let model_info = global_config.read().model_info.clone();
|
|
|
|
if model_info.client != AzureOpenAIClient::name() {
|
|
|
|
let config = {
|
|
|
|
return None;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let local_config = {
|
|
|
|
|
|
|
|
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
|
|
|
|
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
|
|
|
|
c.clone()
|
|
|
|
c.clone()
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -73,59 +76,37 @@ impl AzureOpenAIClient {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
Some(Box::new(Self {
|
|
|
|
Some(Box::new(Self {
|
|
|
|
global_config,
|
|
|
|
global_config,
|
|
|
|
local_config,
|
|
|
|
config,
|
|
|
|
model_info,
|
|
|
|
model_info,
|
|
|
|
}))
|
|
|
|
}))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn name() -> &'static str {
|
|
|
|
pub fn name(local_config: &AzureOpenAIConfig) -> &str {
|
|
|
|
"azure-openai"
|
|
|
|
local_config.name.as_deref().unwrap_or(Self::NAME)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
|
|
|
pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
|
|
|
|
|
|
|
|
let client = Self::name(local_config);
|
|
|
|
|
|
|
|
|
|
|
|
local_config
|
|
|
|
local_config
|
|
|
|
.models
|
|
|
|
.models
|
|
|
|
.iter()
|
|
|
|
.iter()
|
|
|
|
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
|
|
|
|
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
|
|
|
|
.collect()
|
|
|
|
.collect()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn create_config() -> Result<String> {
|
|
|
|
pub fn create_config() -> Result<String> {
|
|
|
|
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
|
|
|
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);
|
|
|
|
|
|
|
|
|
|
|
|
let api_base = Text::new("api_base:")
|
|
|
|
|
|
|
|
.prompt()
|
|
|
|
|
|
|
|
.map_err(|_| anyhow!("An error happened when asking for api base, try again later."))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let api_base = prompt_input_api_base()?;
|
|
|
|
client_config.push_str(&format!(" api_base: {api_base}\n"));
|
|
|
|
client_config.push_str(&format!(" api_base: {api_base}\n"));
|
|
|
|
|
|
|
|
|
|
|
|
if env::var("AZURE_OPENAI_KEY").is_err() {
|
|
|
|
let api_key = prompt_input_api_key()?;
|
|
|
|
let api_key = Text::new("API key:").prompt().map_err(|_| {
|
|
|
|
|
|
|
|
anyhow!("An error happened when asking for api key, try again later.")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
|
|
|
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
|
|
|
|
|
|
|
|
anyhow!("An error happened when asking for model name, try again later.")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
|
|
|
|
let model_name = prompt_input_model_name()?;
|
|
|
|
anyhow!("An error happened when asking for max tokens, try again later.")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let ans = Confirm::new("Use proxy?")
|
|
|
|
let max_tokens = prompt_input_max_token()?;
|
|
|
|
.with_default(false)
|
|
|
|
|
|
|
|
.prompt()
|
|
|
|
|
|
|
|
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ans {
|
|
|
|
|
|
|
|
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
|
|
|
|
|
|
|
anyhow!("An error happened when asking for proxy, try again later.")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client_config.push_str(&format!(
|
|
|
|
client_config.push_str(&format!(
|
|
|
|
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
|
|
|
|
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
|
|
|
@ -134,50 +115,26 @@ impl AzureOpenAIClient {
|
|
|
|
Ok(client_config)
|
|
|
|
Ok(client_config)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
|
|
|
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
|
|
|
let messages = self.global_config.read().build_messages(content)?;
|
|
|
|
let api_key = self.config.api_key.clone();
|
|
|
|
|
|
|
|
let api_key = api_key
|
|
|
|
let mut body = json!({
|
|
|
|
.or_else(|| {
|
|
|
|
"messages": messages,
|
|
|
|
let env_prefix = match &self.config.name {
|
|
|
|
});
|
|
|
|
None => "AZURE".into(),
|
|
|
|
|
|
|
|
Some(v) => v.to_uppercase(),
|
|
|
|
if let Some(v) = self.global_config.read().get_temperature() {
|
|
|
|
|
|
|
|
body.as_object_mut()
|
|
|
|
|
|
|
|
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if stream {
|
|
|
|
|
|
|
|
body.as_object_mut()
|
|
|
|
|
|
|
|
.and_then(|m| m.insert("stream".into(), json!(true)));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let client = {
|
|
|
|
|
|
|
|
let mut builder = ReqwestClient::builder();
|
|
|
|
|
|
|
|
builder = set_proxy(builder, &self.local_config.proxy)?;
|
|
|
|
|
|
|
|
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
|
|
|
|
|
|
|
builder
|
|
|
|
|
|
|
|
.connect_timeout(timeout)
|
|
|
|
|
|
|
|
.build()
|
|
|
|
|
|
|
|
.with_context(|| "Failed to build client")?
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let mut api_base = self.local_config.api_base.clone();
|
|
|
|
env::var(format!("{env_prefix}_OPENAI_KEY")).ok()
|
|
|
|
if !api_base.ends_with('/') {
|
|
|
|
})
|
|
|
|
api_base = format!("{api_base}/");
|
|
|
|
.ok_or_else(|| anyhow!("Miss api_key"))?;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let body = openai_build_body(data, self.model_info.name.clone());
|
|
|
|
|
|
|
|
|
|
|
|
let url = format!(
|
|
|
|
let url = format!(
|
|
|
|
"{api_base}openai/deployments/{}/chat/completions?api-version=2023-05-15",
|
|
|
|
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
|
|
|
|
self.model_info.name
|
|
|
|
self.config.api_base, self.model_info.name
|
|
|
|
);
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
let mut builder = client.post(url);
|
|
|
|
let builder = client.post(url).header("api-key", api_key).json(&body);
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(api_key) = &self.local_config.api_key {
|
|
|
|
|
|
|
|
builder = builder.header("api-key", api_key)
|
|
|
|
|
|
|
|
} else if let Ok(api_key) = env::var("AZURE_OPENAI_KEY") {
|
|
|
|
|
|
|
|
builder = builder.header("api-key", api_key)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
builder = builder.json(&body);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
Ok(builder)
|
|
|
|
}
|
|
|
|
}
|
|
|
|