|
|
|
@ -18,27 +18,50 @@ use std::{env, sync::Mutex};
|
|
|
|
|
const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
|
|
|
|
|
const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";
|
|
|
|
|
|
|
|
|
|
const MODELS: [(&str, usize, &str); 7] = [
|
|
|
|
|
const MODELS: [(&str, &str, usize, isize); 7] = [
|
|
|
|
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
|
|
|
|
("ernie-4.0-8k", 5120, "/wenxinworkshop/chat/completions_pro"),
|
|
|
|
|
(
|
|
|
|
|
"ernie-3.5-8k",
|
|
|
|
|
"ernie-4.0-8k",
|
|
|
|
|
"/wenxinworkshop/chat/completions_pro",
|
|
|
|
|
5120,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-3.5-8k",
|
|
|
|
|
"/wenxinworkshop/chat/ernie-3.5-8k-0205",
|
|
|
|
|
5120,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-3.5-4k",
|
|
|
|
|
2048,
|
|
|
|
|
"/wenxinworkshop/chat/ernie-3.5-4k-0205",
|
|
|
|
|
2048,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-speed-8k",
|
|
|
|
|
"/wenxinworkshop/chat/ernie_speed",
|
|
|
|
|
7168,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
("ernie-speed-8k", 7168, "/wenxinworkshop/chat/ernie_speed"),
|
|
|
|
|
(
|
|
|
|
|
"ernie-speed-128k",
|
|
|
|
|
124000,
|
|
|
|
|
"/wenxinworkshop/chat/ernie-speed-128k",
|
|
|
|
|
124000,
|
|
|
|
|
4096,
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-lite-8k",
|
|
|
|
|
"/wenxinworkshop/chat/ernie-lite-8k",
|
|
|
|
|
7168,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-tiny-8k",
|
|
|
|
|
"/wenxinworkshop/chat/ernie-tiny-8k",
|
|
|
|
|
7168,
|
|
|
|
|
2048,
|
|
|
|
|
),
|
|
|
|
|
("ernie-lite-8k", 7168, "/wenxinworkshop/chat/ernie-lite-8k"),
|
|
|
|
|
("ernie-tiny-8k", 7168, "/wenxinworkshop/chat/ernie-tiny-8k"),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
lazy_static! {
|
|
|
|
@ -85,17 +108,21 @@ impl ErnieClient {
|
|
|
|
|
let client_name = Self::name(local_config);
|
|
|
|
|
MODELS
|
|
|
|
|
.into_iter()
|
|
|
|
|
.map(|(name, _, _)| Model::new(client_name, name)) // ERNIE tokenizer is different from cl100k_base
|
|
|
|
|
.map(|(name, _, max_input_tokens, max_output_tokens)| {
|
|
|
|
|
Model::new(client_name, name)
|
|
|
|
|
.set_max_input_tokens(Some(max_input_tokens))
|
|
|
|
|
.set_max_output_tokens(Some(max_output_tokens))
|
|
|
|
|
}) // ERNIE tokenizer is different from cl100k_base
|
|
|
|
|
.collect()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
|
|
|
|
let body = build_body(data, self.model.name.clone());
|
|
|
|
|
let body = build_body(data, &self.model);
|
|
|
|
|
|
|
|
|
|
let model = self.model.name.clone();
|
|
|
|
|
let (_, _, chat_endpoint) = MODELS
|
|
|
|
|
let model = &self.model.name;
|
|
|
|
|
let (_, chat_endpoint, _, _) = MODELS
|
|
|
|
|
.iter()
|
|
|
|
|
.find(|(v, _, _)| v == &model)
|
|
|
|
|
.find(|(v, _, _, _)| v == model)
|
|
|
|
|
.ok_or_else(|| anyhow!("Miss Model '{}'", self.model.id()))?;
|
|
|
|
|
|
|
|
|
|
let access_token = ACCESS_TOKEN
|
|
|
|
@ -207,7 +234,7 @@ fn check_error(data: &Value) -> Result<()> {
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_body(data: SendData, _model: String) -> Value {
|
|
|
|
|
fn build_body(data: SendData, model: &Model) -> Value {
|
|
|
|
|
let SendData {
|
|
|
|
|
mut messages,
|
|
|
|
|
temperature,
|
|
|
|
@ -223,6 +250,11 @@ fn build_body(data: SendData, _model: String) -> Value {
|
|
|
|
|
if let Some(temperature) = temperature {
|
|
|
|
|
body["temperature"] = temperature.into();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if let Some(max_output_tokens) = model.max_output_tokens {
|
|
|
|
|
body["max_output_tokens"] = max_output_tokens.into();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if stream {
|
|
|
|
|
body["stream"] = true.into();
|
|
|
|
|
}
|
|
|
|
|