refactor: qianwen client add qwen-long (#537)

pull/544/head
Authored by sigoden 4 weeks ago, committed by GitHub
parent 50b13d2de9
commit 2ccbb0f06a

@@ -455,20 +455,20 @@
       max_input_tokens: 124000
       max_output_tokens: 4096
       pass_max_tokens: true
-      input_price: 0.56
-      output_price: 1.12
+      input_price: 0
+      output_price: 0
     - name: ernie-lite-8k
       max_input_tokens: 7168
       max_output_tokens: 2048
       pass_max_tokens: true
-      input_price: 0.42
-      output_price: 0.84
+      input_price: 0
+      output_price: 0
     - name: ernie-tiny-8k
       max_input_tokens: 7168
       max_output_tokens: 2048
       pass_max_tokens: true
-      input_price: 0.14
-      output_price: 0.14
+      input_price: 0
+      output_price: 0
 - platform: qianwen
   # docs:
@@ -477,22 +477,28 @@
   # notes:
   #   - get max_output_tokens info from models doc
   models:
+    - name: qwen-long
+      max_input_tokens: 1000000
+      input_price: 0.07
+      output_price: 0.28
     - name: qwen-turbo
       max_input_tokens: 6000
       max_output_tokens: 1500
-      input_price: 1.12
-      output_price: 1.12
+      input_price: 0.28
+      output_price: 0.84
     - name: qwen-plus
       max_input_tokens: 30000
       max_output_tokens: 2000
-      input_price: 2.8
-      output_price: 2.8
+      input_price: 0.56
+      output_price: 1.68
     - name: qwen-max
       max_input_tokens: 6000
       max_output_tokens: 2000
-      input_price: 16.8
+      input_price: 5.6
       output_price: 16.8
     - name: qwen-max-longcontext
+      input_price: 5.6
+      output_price: 16.8
       max_input_tokens: 28000
       max_output_tokens: 2000
     - name: qwen-vl-plus
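
Note: the input_price/output_price fields above can be turned into a per-request cost estimate. The sketch below assumes the prices are quoted per million tokens (an assumption, not something this diff states) and uses the qwen-long rates from the hunk above; the helper name is made up for illustration.

// Hypothetical helper: estimate one request's cost from the config's
// input_price/output_price fields, assuming they are per-1M-token rates.
fn estimate_cost(input_tokens: u64, output_tokens: u64, input_price: f64, output_price: f64) -> f64 {
    (input_tokens as f64 * input_price + output_tokens as f64 * output_price) / 1_000_000.0
}

fn main() {
    // qwen-long rates from the config change above: input 0.07, output 0.28.
    let cost = estimate_cost(120_000, 2_000, 0.07, 0.28);
    println!("estimated cost: {cost:.4}"); // prints 0.0090 under these assumptions
}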

@@ -42,13 +42,12 @@ impl QianwenClient {
         let api_key = self.get_api_key()?;
 
         let stream = data.stream;
 
-        let is_vl = self.is_vl();
-        let url = match is_vl {
+        let url = match self.model.supports_vision() {
             true => API_URL_VL,
             false => API_URL,
         };
-        let (mut body, has_upload) = build_body(data, &self.model, is_vl)?;
+        let (mut body, has_upload) = build_body(data, &self.model)?;
         self.patch_request_body(&mut body);
 
         debug!("Qianwen Request: {url} {body}");
@@ -63,12 +62,10 @@ impl QianwenClient {
         Ok(builder)
     }
-
-    fn is_vl(&self) -> bool {
-        self.model.name().starts_with("qwen-vl")
-    }
 }
 
 #[async_trait]
 impl Client for QianwenClient {
     client_common_fns!();
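
Note: the deleted is_vl helper inferred vision support from the model name prefix; the call sites now go through model.supports_vision(), which suggests a per-model capability lookup instead. Below is a minimal standalone sketch of that pattern with hypothetical field names, not aichat's actual Model type.

// Illustration only: an explicit capability flag replacing a name-prefix check.
struct Model {
    name: String,
    vision: bool, // hypothetical: would come from per-model config
}

impl Model {
    // Old style: infer the capability from the model name.
    fn is_vl_by_name(&self) -> bool {
        self.name.starts_with("qwen-vl")
    }

    // New style: consult an explicit capability flag.
    fn supports_vision(&self) -> bool {
        self.vision
    }
}

fn main() {
    let vl = Model { name: "qwen-vl-plus".into(), vision: true };
    let long = Model { name: "qwen-long".into(), vision: false };
    assert!(vl.is_vl_by_name() && vl.supports_vision());
    assert!(!long.is_vl_by_name() && !long.supports_vision());
    println!("capability checks passed");
}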
@@ -81,7 +78,7 @@ impl Client for QianwenClient {
         let api_key = self.get_api_key()?;
         patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
         let builder = self.request_builder(client, data)?;
-        send_message(builder, self.is_vl()).await
+        send_message(builder, &self.model).await
     }
 
     async fn send_message_streaming_inner(
@@ -93,28 +90,35 @@ impl Client for QianwenClient {
         let api_key = self.get_api_key()?;
         patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
         let builder = self.request_builder(client, data)?;
-        send_message_streaming(builder, handler, self.is_vl()).await
+        send_message_streaming(builder, handler, &self.model).await
     }
 }
 
-async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<CompletionOutput> {
+async fn send_message(builder: RequestBuilder, model: &Model) -> Result<CompletionOutput> {
     let data: Value = builder.send().await?.json().await?;
     maybe_catch_error(&data)?;
     debug!("non-stream-data: {data}");
-    extract_completion_text(&data, is_vl)
+    extract_completion_text(&data, model)
 }
 
 async fn send_message_streaming(
     builder: RequestBuilder,
     handler: &mut SseHandler,
-    is_vl: bool,
+    model: &Model,
 ) -> Result<()> {
+    let model_name = model.name();
     let handle = |message: SsMmessage| -> Result<bool> {
         let data: Value = serde_json::from_str(&message.data)?;
         maybe_catch_error(&data)?;
         debug!("stream-data: {data}");
-        if is_vl {
+        if model_name == "qwen-long" {
+            if let Some(text) =
+                data["output"]["choices"][0]["message"]["content"].as_str()
+            {
+                handler.text(text)?;
+            }
+        } else if model.supports_vision() {
             if let Some(text) =
                 data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
             {
@@ -129,7 +133,7 @@ async fn send_message_streaming(
     sse_stream(builder, handle).await
 }
 
-fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
+fn build_body(data: SendData, model: &Model) -> Result<(Value, bool)> {
     let SendData {
         messages,
         temperature,
@@ -140,7 +144,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
     let mut has_upload = false;
     let mut is_tool_call = false;
 
-    let input = if is_vl {
+    let input = if model.supports_vision() {
         let messages: Vec<Value> = messages
             .into_iter()
             .map(|message| {
@@ -206,9 +210,13 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
     Ok((body, has_upload))
 }
 
-fn extract_completion_text(data: &Value, is_vl: bool) -> Result<CompletionOutput> {
+fn extract_completion_text(data: &Value, model: &Model) -> Result<CompletionOutput> {
     let err = || anyhow!("Invalid response data: {data}");
-    let text = if is_vl {
+    let text = if model.name() == "qwen-long" {
+        data["output"]["choices"][0]["message"]["content"]
+            .as_str()
+            .ok_or_else(err)?
+    } else if model.supports_vision() {
         data["output"]["choices"][0]["message"]["content"][0]["text"]
             .as_str()
             .ok_or_else(err)?
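
Note: the two branches above index the response differently because, judging from this code, qwen-long returns the message content as a plain string while the vision models wrap it in an array of parts. A self-contained illustration of both extraction paths follows; the JSON shapes are inferred from the code above, not taken from the Qianwen API docs.

use serde_json::{json, Value};

fn main() {
    // Shape handled by the qwen-long branch: content is a plain string.
    let long_style: Value = json!({
        "output": { "choices": [ { "message": { "content": "hello" } } ] }
    });
    // Shape handled by the vision branch: content is an array of parts.
    let vl_style: Value = json!({
        "output": { "choices": [ { "message": { "content": [ { "text": "hello" } ] } } ] }
    });

    let long_text = long_style["output"]["choices"][0]["message"]["content"].as_str();
    let vl_text = vl_style["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
    assert_eq!(long_text, Some("hello"));
    assert_eq!(vl_text, Some("hello"));
    println!("both paths extract: {}", long_text.unwrap());
}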
