|
|
|
@ -14,11 +14,23 @@ use std::env;
|
|
|
|
|
const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
|
|
|
|
|
const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";
|
|
|
|
|
|
|
|
|
|
const MODELS: [(&str, &str); 4] = [
|
|
|
|
|
("ernie-bot", "/wenxinworkshop/chat/completions"),
|
|
|
|
|
("ernie-bot-4", "/wenxinworkshop/chat/completions_pro"),
|
|
|
|
|
("ernie-bot-8k", "/wenxinworkshop/chat/ernie_bot_8k"),
|
|
|
|
|
("ernie-bot-turbo", "/wenxinworkshop/chat/eb-instant"),
|
|
|
|
|
const MODELS: [(&str, usize, &str); 7] = [
|
|
|
|
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
|
|
|
|
("ernie-bot-4", 5120, "/wenxinworkshop/chat/completions_pro"),
|
|
|
|
|
("ernie-bot-8k", 5120, "/wenxinworkshop/chat/ernie_bot_8k"),
|
|
|
|
|
("ernie-bot", 2048, "/wenxinworkshop/chat/completions"),
|
|
|
|
|
(
|
|
|
|
|
"ernie-3.5-4k-0205",
|
|
|
|
|
2048,
|
|
|
|
|
"/wenxinworkshop/chat/ernie-3.5-4k-0205",
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ernie-3.5-8k-0205",
|
|
|
|
|
5120,
|
|
|
|
|
"/wenxinworkshop/chat/ernie-3.5-8k-0205",
|
|
|
|
|
),
|
|
|
|
|
("ernie-speed", 7168, "/wenxinworkshop/chat/ernie_speed"),
|
|
|
|
|
("ernie-bot-turbo", 7168, "/wenxinworkshop/chat/eb-instant"),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
static mut ACCESS_TOKEN: String = String::new(); // safe under linear operation
|
|
|
|
@ -63,7 +75,9 @@ impl ErnieClient {
|
|
|
|
|
let client_name = Self::name(local_config);
|
|
|
|
|
MODELS
|
|
|
|
|
.into_iter()
|
|
|
|
|
.map(|(name, _)| Model::new(client_name, name))
|
|
|
|
|
.map(|(name, max_input_tokens, _)| {
|
|
|
|
|
Model::new(client_name, name).set_max_input_tokens(Some(max_input_tokens))
|
|
|
|
|
})
|
|
|
|
|
.collect()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -71,9 +85,9 @@ impl ErnieClient {
|
|
|
|
|
let body = build_body(data, self.model.name.clone());
|
|
|
|
|
|
|
|
|
|
let model = self.model.name.clone();
|
|
|
|
|
let (_, chat_endpoint) = MODELS
|
|
|
|
|
let (_, _, chat_endpoint) = MODELS
|
|
|
|
|
.iter()
|
|
|
|
|
.find(|(v, _)| v == &model)
|
|
|
|
|
.find(|(v, _, _)| v == &model)
|
|
|
|
|
.ok_or_else(|| anyhow!("Miss Model '{}'", self.model.id()))?;
|
|
|
|
|
|
|
|
|
|
let url = format!("{API_BASE}{chat_endpoint}?access_token={}", unsafe {
|
|
|
|
|