feat: add new ernie models (#340)

pull/343/head
sigoden 3 months ago committed by GitHub
parent 8e5d4e55b1
commit 20c78d6f15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -14,11 +14,23 @@ use std::env;
const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";
const MODELS: [(&str, &str); 4] = [
("ernie-bot", "/wenxinworkshop/chat/completions"),
("ernie-bot-4", "/wenxinworkshop/chat/completions_pro"),
("ernie-bot-8k", "/wenxinworkshop/chat/ernie_bot_8k"),
("ernie-bot-turbo", "/wenxinworkshop/chat/eb-instant"),
const MODELS: [(&str, usize, &str); 7] = [
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
("ernie-bot-4", 5120, "/wenxinworkshop/chat/completions_pro"),
("ernie-bot-8k", 5120, "/wenxinworkshop/chat/ernie_bot_8k"),
("ernie-bot", 2048, "/wenxinworkshop/chat/completions"),
(
"ernie-3.5-4k-0205",
2048,
"/wenxinworkshop/chat/ernie-3.5-4k-0205",
),
(
"ernie-3.5-8k-0205",
5120,
"/wenxinworkshop/chat/ernie-3.5-8k-0205",
),
("ernie-speed", 7168, "/wenxinworkshop/chat/ernie_speed"),
("ernie-bot-turbo", 7168, "/wenxinworkshop/chat/eb-instant"),
];
static mut ACCESS_TOKEN: String = String::new(); // safe under linear operation
@ -63,7 +75,9 @@ impl ErnieClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, _)| Model::new(client_name, name))
.map(|(name, max_input_tokens, _)| {
Model::new(client_name, name).set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
@ -71,9 +85,9 @@ impl ErnieClient {
let body = build_body(data, self.model.name.clone());
let model = self.model.name.clone();
let (_, chat_endpoint) = MODELS
let (_, _, chat_endpoint) = MODELS
.iter()
.find(|(v, _)| v == &model)
.find(|(v, _, _)| v == &model)
.ok_or_else(|| anyhow!("Miss Model '{}'", self.model.id()))?;
let url = format!("{API_BASE}{chat_endpoint}?access_token={}", unsafe {

Loading…
Cancel
Save