From 9037074eb238221d38aa81112328d9c919c0bfc2 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 24 Jul 2024 15:51:00 +0800 Subject: [PATCH] fix: invalid tool_calls of qianwen client (#740) --- src/client/qianwen.rs | 72 +++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index ec84011..3e2763b 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -156,23 +156,21 @@ async fn chat_completions_streaming( debug!("stream-data: {data}"); if model_name == "qwen-long" { if let Some(text) = data["output"]["choices"][0]["message"]["content"].as_str() { - let delta_text = &text[prev_text.len()..]; - prev_text = text.to_string(); - handler.text(delta_text)?; + handler.text(text)?; } } else if model.supports_vision() { if let Some(text) = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() { - let delta_text = &text[prev_text.len()..]; - prev_text = text.to_string(); - handler.text(delta_text)?; + handler.text(text)?; } } else if let Some(text) = data["output"]["text"].as_str() { if let Some(pos) = text.rfind("✿FUNCTION") { if pos > prev_text.len() { let delta_text = &text[prev_text.len()..pos]; - handler.text(delta_text)?; + if delta_text != ": \n" { + handler.text(delta_text)?; + } } prev_text = text.to_string(); if let Some((name, arguments)) = parse_tool_call(&text[pos..]) { @@ -182,7 +180,10 @@ async fn chat_completions_streaming( handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?; } } else { - let delta_text = &text[prev_text.len()..]; + let mut delta_text = &text[prev_text.len()..]; + if prev_text.is_empty() && delta_text.starts_with(": ") { + delta_text = &delta_text[2..]; + } prev_text = text.to_string(); handler.text(delta_text)?; } @@ -199,7 +200,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu temperature, top_p, functions, - stream: _, + stream, } = data; let mut has_upload = false; @@ -236,11 +237,12 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu "messages": messages, }) } else { - let messages: Vec = messages - .into_iter() - .flat_map(|message| { - let role = message.role; - match message.content { + let messages: Vec = + messages + .into_iter() + .flat_map(|message| { + let role = message.role; + match message.content { MessageContent::Text(text) => vec![json!({ "role": role, "content": text })], MessageContent::Array(list) => { let parts: Vec<_> = list @@ -260,30 +262,30 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu vec![json!({ "role": role, "content": parts })] } MessageContent::ToolResults((tool_results, _)) => { - let content = tool_results - .iter() - .map(|tool_result| { - format!( - "✿FUNCTION✿: {}\n✿ARGS✿: {}\n✿RESULT✿", - tool_result.call.name, tool_result.call.arguments - ) - }) - .collect::>() - .join("\n"); - let mut messages = - vec![json!({ "role": MessageRole::Assistant, "content": content })]; - for tool_result in tool_results { - messages.push(json!({ + tool_results.into_iter().flat_map(|tool_result| vec![ + json!({ + "role": MessageRole::Assistant, + "content": "", + "tool_calls": vec![ + json!({ + "type": "function", + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments.to_string(), + }, + }) + ], + }), + json!({ "role": "tool", "content": tool_result.output.to_string(), "name": tool_result.call.name, - })); - } - messages + }), + ]).collect() } } - }) - .collect(); + }) + .collect(); json!({ "messages": messages, }) @@ -291,6 +293,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu let mut parameters = json!({}); + if stream && (model.name() == "qwen-long" || model.supports_vision()) { + parameters["incremental_output"] = true.into(); + } + if let Some(v) = model.max_tokens_param() { parameters["max_tokens"] = v.into(); }