mirror of
https://github.com/sigoden/aichat
synced 2024-11-16 06:15:26 +00:00
feat: support cohere (#397)
This commit is contained in:
parent
ce1f9929f2
commit
5915bc2f3a
@ -21,6 +21,7 @@ Chat REPL mode:
|
||||
- Gemini (free, vision)
|
||||
- Claude (paid)
|
||||
- Mistral (paid)
|
||||
- Cohere (paid)
|
||||
- OpenAI-Compatible (local)
|
||||
- Ollama (free, local)
|
||||
- Azure-OpenAI (paid)
|
||||
|
@ -45,9 +45,14 @@ clients:
|
||||
- type: claude
|
||||
api_key: sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
# See https://docs.mistral.ai/
|
||||
- type: mistral
|
||||
api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
# See https://docs.cohere.com/docs/the-cohere-platform
|
||||
- type: cohere
|
||||
api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
# Any openai-compatible API providers
|
||||
- type: openai-compatible
|
||||
name: localai
|
||||
|
244
src/client/cohere.rs
Normal file
244
src/client/cohere.rs
Normal file
@ -0,0 +1,244 @@
|
||||
use super::{
|
||||
message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, PromptType,
|
||||
SendData, TokensCountFactors,
|
||||
};
|
||||
|
||||
use crate::{render::ReplyHandler, utils::PromptKind};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Cohere chat endpoint that every request built by this client is posted to.
const API_URL: &str = "https://api.cohere.ai/v1/chat";

/// Built-in models as `(name, max input tokens, capabilities)` tuples.
///
/// See https://docs.cohere.com/docs/command-r
const MODELS: [(&str, usize, &str); 2] = [
    ("command-r", 128_000, "text"),
    ("command-r-plus", 128_000, "text"),
];
|
||||
|
||||
// Token-count factors passed to `set_tokens_count_factors` for every model listed by this client.
// NOTE(review): what each of the two numbers means is defined by `TokensCountFactors` elsewhere — confirm there.
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct CohereConfig {
|
||||
pub name: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for CohereClient {
|
||||
client_common_fns!();
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl CohereClient {
|
||||
config_get_fn!(api_key, get_api_key);
|
||||
|
||||
pub const PROMPTS: [PromptType<'static>; 1] =
|
||||
[("api_key", "API Key:", false, PromptKind::String)];
|
||||
|
||||
pub fn list_models(local_config: &CohereConfig) -> Vec<Model> {
|
||||
let client_name = Self::name(local_config);
|
||||
MODELS
|
||||
.into_iter()
|
||||
.map(|(name, max_input_tokens, capabilities)| {
|
||||
Model::new(client_name, name)
|
||||
.set_capabilities(capabilities.into())
|
||||
.set_max_input_tokens(Some(max_input_tokens))
|
||||
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let mut body = build_body(data, self.model.name.clone())?;
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
|
||||
let url = API_URL;
|
||||
|
||||
debug!("Cohere Request: {url} {body}");
|
||||
|
||||
let mut builder = client.post(url).json(&body);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if status != 200 {
|
||||
check_error(&data)?;
|
||||
}
|
||||
let output = extract_text(&data)?;
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_message_streaming(
|
||||
builder: RequestBuilder,
|
||||
handler: &mut ReplyHandler,
|
||||
) -> Result<()> {
|
||||
let res = builder.send().await?;
|
||||
if res.status() != 200 {
|
||||
let data: Value = res.json().await?;
|
||||
check_error(&data)?;
|
||||
} else {
|
||||
let mut buffer = vec![];
|
||||
let mut cursor = 0;
|
||||
let mut start = 0;
|
||||
let mut balances = vec![];
|
||||
let mut quoting = false;
|
||||
let mut stream = res.bytes_stream();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk = std::str::from_utf8(&chunk)?;
|
||||
buffer.extend(chunk.chars());
|
||||
for i in cursor..buffer.len() {
|
||||
let ch = buffer[i];
|
||||
if quoting {
|
||||
if ch == '"' && buffer[i - 1] != '\\' {
|
||||
quoting = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match ch {
|
||||
'"' => quoting = true,
|
||||
'{' => {
|
||||
if balances.is_empty() {
|
||||
start = i;
|
||||
}
|
||||
balances.push(ch);
|
||||
}
|
||||
'[' => {
|
||||
if start != 0 {
|
||||
balances.push(ch);
|
||||
}
|
||||
}
|
||||
'}' => {
|
||||
balances.pop();
|
||||
if balances.is_empty() {
|
||||
let value: String = buffer[start..=i].iter().collect();
|
||||
let value: Value = serde_json::from_str(&value)?;
|
||||
if let Some("text-generation") = value["event_type"].as_str() {
|
||||
handler.text(extract_text(&value)?)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
']' => {
|
||||
balances.pop();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
cursor = buffer.len();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn extract_text(data: &Value) -> Result<&str> {
|
||||
match data["text"].as_str() {
|
||||
Some(text) => Ok(text),
|
||||
None => {
|
||||
bail!("Invalid response data: {data}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_error(data: &Value) -> Result<()> {
|
||||
if let Some(message) = data["message"].as_str() {
|
||||
bail!("{message}");
|
||||
} else {
|
||||
bail!("Error {}", data);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_body(data: SendData, model: String) -> Result<Value> {
|
||||
let SendData {
|
||||
mut messages,
|
||||
temperature,
|
||||
stream,
|
||||
} = data;
|
||||
|
||||
patch_system_message(&mut messages);
|
||||
|
||||
let mut image_urls = vec![];
|
||||
let mut messages: Vec<Value> = messages
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
let role = match message.role {
|
||||
MessageRole::User => "USER",
|
||||
_ => "CHATBOT",
|
||||
};
|
||||
match message.content {
|
||||
MessageContent::Text(text) => json!({
|
||||
"role": role,
|
||||
"message": text,
|
||||
}),
|
||||
MessageContent::Array(list) => {
|
||||
let list: Vec<String> = list
|
||||
.into_iter()
|
||||
.filter_map(|item| match item {
|
||||
MessageContentPart::Text { text } => Some(text),
|
||||
MessageContentPart::ImageUrl {
|
||||
image_url: ImageUrl { url },
|
||||
} => {
|
||||
image_urls.push(url.clone());
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!({ "role": role, "message": list.join("\n\n") })
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !image_urls.is_empty() {
|
||||
bail!("The model does not support images: {:?}", image_urls);
|
||||
}
|
||||
let message = messages.pop().unwrap();
|
||||
let message = message["message"].as_str().unwrap_or_default();
|
||||
|
||||
let mut body = json!({
|
||||
"model": model,
|
||||
"message": message,
|
||||
});
|
||||
|
||||
if !messages.is_empty() {
|
||||
body["chat_history"] = messages.into();
|
||||
}
|
||||
|
||||
if let Some(temperature) = temperature {
|
||||
body["temperature"] = temperature.into();
|
||||
}
|
||||
if stream {
|
||||
body["stream"] = true.into();
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
@ -12,6 +12,7 @@ register_client!(
|
||||
(gemini, "gemini", GeminiConfig, GeminiClient),
|
||||
(claude, "claude", ClaudeConfig, ClaudeClient),
|
||||
(mistral, "mistral", MistralConfig, MistralClient),
|
||||
(cohere, "cohere", CohereConfig, CohereClient),
|
||||
(
|
||||
openai_compatible,
|
||||
"openai-compatible",
|
||||
|
Loading…
Reference in New Issue
Block a user