From 6c9d7a679ea0097ab12f295caf16bec3e75ee267 Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 19 Dec 2023 23:10:35 +0800 Subject: [PATCH] feat: support qianwen:qwen-vl-plus (#275) --- src/client/gemini.rs | 19 +++--- src/client/qianwen.rs | 150 +++++++++++++++++++++++++++++++----------- 2 files changed, 121 insertions(+), 48 deletions(-) diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 6497069..98fd60d 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -1,9 +1,9 @@ use super::{ - patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType, SendData, - TokensCountFactors, + message::*, patch_system_message, Client, ExtraConfig, GeminiClient, Model, PromptType, + SendData, TokensCountFactors, }; -use crate::{client::*, config::GlobalConfig, render::ReplyHandler, utils::PromptKind}; +use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind}; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; @@ -123,7 +123,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand for i in cursor..buffer.len() { let ch = buffer[i]; if quoting { - if ch == '"' && buffer[i-1] != '\\' { + if ch == '"' && buffer[i - 1] != '\\' { quoting = false; } continue; @@ -189,7 +189,7 @@ fn build_body(data: SendData, _model: String) -> Result { patch_system_message(&mut messages); - let mut invalid_urls = vec![]; + let mut network_image_urls = vec![]; let contents: Vec = messages .into_iter() .map(|message| { @@ -211,7 +211,7 @@ fn build_body(data: SendData, _model: String) -> Result { if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) { json!({ "inline_data": { "mime_type": mime_type, "data": data } }) } else { - invalid_urls.push(url.clone()); + network_image_urls.push(url.clone()); json!({ "url": url }) } }, @@ -223,8 +223,11 @@ fn build_body(data: SendData, _model: String) -> Result { }) .collect(); - if !invalid_urls.is_empty() { - bail!("The model does not support non-data URLs: {:?}", invalid_urls); + if !network_image_urls.is_empty() { + bail!( + "The model does not support network images: {:?}", + network_image_urls + ); } let mut body = json!({ diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 619ed76..ed4b6e6 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -1,4 +1,4 @@ -use super::{Client, ExtraConfig, Model, PromptType, QianwenClient, SendData}; +use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData}; use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind}; @@ -13,10 +13,15 @@ use serde_json::{json, Value}; const API_URL: &str = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"; -const MODELS: [(&str, usize); 3] = [ - ("qwen-turbo", 6144), - ("qwen-plus", 6144), - ("qwen-max", 6144), +const API_URL_VL: &str = + "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"; + +const MODELS: [(&str, usize); 5] = [ + ("qwen-turbo", 8192), + ("qwen-plus", 32768), + ("qwen-max", 8192), + ("qwen-max-longcontext", 30720), + ("qwen-vl-plus", 0), ]; #[derive(Debug, Clone, Deserialize, Default)] @@ -34,7 +39,7 @@ impl Client for QianwenClient { async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { let builder = self.request_builder(client, data)?; - send_message(builder).await + send_message(builder, self.is_vl()).await } async fn send_message_streaming_inner( @@ -44,7 +49,7 @@ impl Client for QianwenClient { data: SendData, ) -> Result<()> { let builder = self.request_builder(client, data)?; - send_message_streaming(builder, handler).await + send_message_streaming(builder, handler, self.is_vl()).await } } @@ -68,49 +73,71 @@ impl QianwenClient { let api_key = self.get_api_key()?; let stream = data.stream; - let body = build_body(data, self.model.name.clone()); - debug!("Qianwen Request: {API_URL} {body}"); + let is_vl = self.is_vl(); + let url = match is_vl { + true => API_URL_VL, + false => API_URL, + }; + let body = build_body(data, self.model.name.clone(), is_vl)?; + + debug!("Qianwen Request: {url} {body}"); - let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body); + let mut builder = client.post(url).bearer_auth(api_key).json(&body); if stream { builder = builder.header("X-DashScope-SSE", "enable"); } Ok(builder) } + + fn is_vl(&self) -> bool { + self.model.name.starts_with("qwen-vl") + } } -async fn send_message(builder: RequestBuilder) -> Result { +async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result { let data: Value = builder.send().await?.json().await?; check_error(&data)?; - let output = data["output"]["text"] - .as_str() - .ok_or_else(|| anyhow!("Unexpected response {data}"))?; + let output = if is_vl { + data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() + } else { + data["output"]["text"].as_str() + }; + + let output = output.ok_or_else(|| anyhow!("Unexpected response {data}"))?; Ok(output.to_string()) } -async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { +async fn send_message_streaming( + builder: RequestBuilder, + handler: &mut ReplyHandler, + is_vl: bool, +) -> Result<()> { let mut es = builder.eventsource()?; + let mut offset = 0; while let Some(event) = es.next().await { match event { Ok(Event::Open) => {} Ok(Event::Message(message)) => { let data: Value = serde_json::from_str(&message.data)?; - if let Some(text) = data["output"]["text"].as_str() { + check_error(&data)?; + if is_vl { + let text = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str(); + if let Some(text) = text { + let text = &text[offset..]; + handler.text(text)?; + offset += text.len(); + } + } else if let Some(text) = data["output"]["text"].as_str() { handler.text(text)?; } } Err(err) => { match err { - EventSourceError::InvalidStatusCode(_, res) => { - let data: Value = res.json().await?; - check_error(&data)?; - bail!("Request failed"); - } EventSourceError::StreamEnded => {} _ => { bail!("{}", err); @@ -125,38 +152,81 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand } fn check_error(data: &Value) -> Result<()> { - if let Some(code) = data["code"].as_str() { - if let Some(message) = data["message"].as_str() { - bail!("{message}"); - } else { - bail!("{code}"); - } + if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) { + bail!("{code}: {message}"); } Ok(()) } -fn build_body(data: SendData, model: String) -> Value { +fn build_body(data: SendData, model: String, is_vl: bool) -> Result { let SendData { messages, temperature, stream, } = data; - let mut parameters = json!({}); - if stream { - parameters["incremental_output"] = true.into(); - } + let (input, parameters) = if is_vl { + let mut exist_embeded_image = false; - if let Some(v) = temperature { - parameters["temperature"] = v.into(); - } + let messages: Vec = messages + .into_iter() + .map(|message| { + let role = message.role; + let content = match message.content { + MessageContent::Text(text) => vec![json!({"text": text})], + MessageContent::Array(list) => list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => json!({"text": text}), + MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + } => { + if url.starts_with("data:") { + exist_embeded_image = true; + } + json!({"image": url}) + }, + }) + .collect(), + }; + json!({ "role": role, "content": content }) + }) + .collect(); - json!({ - "model": model, - "input": json!({ + if exist_embeded_image { + bail!("The model does not support embeded images"); + } + + let input = json!({ + "messages": messages, + }); + + let mut parameters = json!({}); + if let Some(v) = temperature { + parameters["top_k"] = ((v * 50.0).round() as usize).into(); + } + (input, parameters) + } else { + let input = json!({ "messages": messages, - }), + }); + + let mut parameters = json!({}); + if stream { + parameters["incremental_output"] = true.into(); + } + + if let Some(v) = temperature { + parameters["temperature"] = v.into(); + } + (input, parameters) + }; + + let body = json!({ + "model": model, + "input": input, "parameters": parameters - }) + }); + Ok(body) }