diff --git a/README.md b/README.md index eb202a4..a061bf2 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases), - Azure-OpenAI: user deployed gpt3.5/gpt4 - PaLM: chat-bison-001 - Ernie: eb-instant/ernie-bot/ernie-bot-4 +- Qianwen: qwen-turbo/qwen-plus ## Features diff --git a/config.example.yaml b/config.example.yaml index 8978536..30e00c9 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -45,4 +45,8 @@ clients: # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html - type: ernie api_key: xxx - secret_key: xxx \ No newline at end of file + secret_key: xxx + + # See https://help.aliyun.com/zh/dashscope/ + - type: qianwen + api_key: xxx \ No newline at end of file diff --git a/src/client/mod.rs b/src/client/mod.rs index ab58056..be29d79 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -18,4 +18,5 @@ register_client!( ), (palm, "palm", PaLMConfig, PaLMClient), (ernie, "ernie", ErnieConfig, ErnieClient), + (qianwen, "qianwen", QianwenConfig, QianwenClient), ); diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs new file mode 100644 index 0000000..d5145c2 --- /dev/null +++ b/src/client/qianwen.rs @@ -0,0 +1,160 @@ +use super::{QianwenClient, Client, ExtraConfig, PromptType, SendData, Model}; + +use crate::{ + config::GlobalConfig, + render::ReplyHandler, + utils::PromptKind, +}; + +use anyhow::{anyhow, bail, Result}; +use async_trait::async_trait; +use futures_util::StreamExt; +use reqwest::{Client as ReqwestClient, RequestBuilder}; +use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; +use serde::Deserialize; +use serde_json::{json, Value}; + +const API_URL: &str = + "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"; + +const MODELS: [(&str, usize); 2] = [("qwen-turbo", 6144), ("qwen-plus", 6144)]; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct QianwenConfig { + pub name: Option, + pub api_key: Option, + pub extra: Option, +} + +#[async_trait] +impl Client for QianwenClient { + fn config(&self) -> (&GlobalConfig, &Option) { + (&self.global_config, &self.config.extra) + } + + async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { + let builder = self.request_builder(client, data)?; + send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut ReplyHandler, + data: SendData, + ) -> Result<()> { + let builder = self.request_builder(client, data)?; + send_message_streaming(builder, handler).await + } +} + +impl QianwenClient { + config_get_fn!(api_key, get_api_key); + + pub const PROMPTS: [PromptType<'static>; 1] = + [("api_key", "API Key:", true, PromptKind::String)]; + + pub fn list_models(local_config: &QianwenConfig, client_index: usize) -> Vec { + let client_name = Self::name(local_config); + MODELS + .into_iter() + .map(|(name, max_tokens)| Model::new(client_index, client_name, name).set_max_tokens(Some(max_tokens))) + .collect() + } + + fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + let api_key = self.get_api_key()?; + + let stream = data.stream; + let body = build_body(data, self.model.llm_name.clone()); + + let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body); + if stream { + builder = builder.header("X-DashScope-SSE", "enable"); + } + + Ok(builder) + } +} + +async fn send_message(builder: RequestBuilder) -> Result { + let data: Value = builder.send().await?.json().await?; + check_error(&data)?; + + let output = data["output"]["text"].as_str() + .ok_or_else(|| anyhow!("Unexpected response {data}"))?; + + Ok(output.to_string()) +} + +async fn send_message_streaming( + builder: RequestBuilder, + handler: &mut ReplyHandler, +) -> Result<()> { + let mut es = builder.eventsource()?; + let mut offset = 0; + + while let Some(event) = es.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(message)) => { + let data: Value = serde_json::from_str(&message.data)?; + if let Some(text) = data["output"]["text"].as_str() { + + let text = &text[offset..]; + handler.text(text)?; + offset += text.len(); + } + } + Err(err) => { + match err { + EventSourceError::InvalidStatusCode(_, res) => { + let data: Value = res.json().await?; + check_error(&data)?; + bail!("Request failed"); + } + EventSourceError::StreamEnded => {} + _ => { + bail!("{}", err); + } + } + es.close(); + } + } + } + + Ok(()) +} + +fn check_error(data: &Value) -> Result<()> { + if let Some(code) = data["code"].as_str() { + if let Some(message) = data["message"].as_str() { + bail!("{message}"); + } else { + bail!("{code}"); + } + } + Ok(()) +} + +fn build_body(data: SendData, model: String) -> Value { + let SendData { + messages, + temperature, + stream: _, + } = data; + + let mut parameters = json!({}); + + if let Some(v) = temperature { + parameters["temperature"] = v.into(); + } + + json!({ + "model": model, + "input": json!({ + "messages": messages, + }), + "parameters": parameters + }) +}