fix: invalid tool_calls of qianwen client (#740)

Branch: pull/741/head
Author: sigoden (committed via GitHub)
Parent: cf9d06f51e
Commit: 9037074eb2

@@ -156,23 +156,21 @@ async fn chat_completions_streaming(
         debug!("stream-data: {data}");
         if model_name == "qwen-long" {
             if let Some(text) = data["output"]["choices"][0]["message"]["content"].as_str() {
-                let delta_text = &text[prev_text.len()..];
-                prev_text = text.to_string();
-                handler.text(delta_text)?;
+                handler.text(text)?;
             }
         } else if model.supports_vision() {
             if let Some(text) =
                 data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
             {
-                let delta_text = &text[prev_text.len()..];
-                prev_text = text.to_string();
-                handler.text(delta_text)?;
+                handler.text(text)?;
             }
         } else if let Some(text) = data["output"]["text"].as_str() {
             if let Some(pos) = text.rfind("✿FUNCTION") {
                 if pos > prev_text.len() {
                     let delta_text = &text[prev_text.len()..pos];
-                    handler.text(delta_text)?;
+                    if delta_text != ": \n" {
+                        handler.text(delta_text)?;
+                    }
                 }
                 prev_text = text.to_string();
                 if let Some((name, arguments)) = parse_tool_call(&text[pos..]) {
@@ -182,7 +180,10 @@ async fn chat_completions_streaming(
                     handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?;
                 }
             } else {
-                let delta_text = &text[prev_text.len()..];
+                let mut delta_text = &text[prev_text.len()..];
+                if prev_text.is_empty() && delta_text.starts_with(": ") {
+                    delta_text = &delta_text[2..];
+                }
                 prev_text = text.to_string();
                 handler.text(delta_text)?;
             }
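In short, the streaming handler changes in two ways: qwen-long and vision models now stream with incremental_output enabled (see the parameters change at the end of this diff), so each chunk is forwarded as-is instead of being diffed against the accumulated text; and on the function-calling path, the stray ": " / ": \n" echo that the ✿FUNCTION✿ prompt format makes the model emit before the real answer is dropped. A minimal sketch of that prefix handling, outside the diff and with a hypothetical helper name:

    // Sketch only: drop the ": " echo from the first chunk of a full-output stream.
    // `strip_prompt_echo` is a hypothetical helper, not part of this commit.
    fn strip_prompt_echo<'a>(prev_text: &str, text: &'a str) -> &'a str {
        let mut delta_text = &text[prev_text.len()..];
        if prev_text.is_empty() && delta_text.starts_with(": ") {
            delta_text = &delta_text[2..];
        }
        delta_text
    }

    #[test]
    fn strips_echo_only_at_stream_start() {
        assert_eq!(strip_prompt_echo("", ": hi"), "hi");
        assert_eq!(strip_prompt_echo("hi", "hi there"), " there");
    }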
@@ -199,7 +200,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
         temperature,
         top_p,
         functions,
-        stream: _,
+        stream,
     } = data;

     let mut has_upload = false;
@@ -236,11 +237,12 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
             "messages": messages,
         })
     } else {
-        let messages: Vec<Value> = messages
-            .into_iter()
-            .flat_map(|message| {
-                let role = message.role;
-                match message.content {
+        let messages: Vec<Value> =
+            messages
+                .into_iter()
+                .flat_map(|message| {
+                    let role = message.role;
+                    match message.content {
                     MessageContent::Text(text) => vec![json!({ "role": role, "content": text })],
                     MessageContent::Array(list) => {
                         let parts: Vec<_> = list
@@ -260,30 +262,30 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
                         vec![json!({ "role": role, "content": parts })]
                     }
                     MessageContent::ToolResults((tool_results, _)) => {
-                        let content = tool_results
-                            .iter()
-                            .map(|tool_result| {
-                                format!(
-                                    "✿FUNCTION✿: {}\n✿ARGS✿: {}\n✿RESULT✿",
-                                    tool_result.call.name, tool_result.call.arguments
-                                )
-                            })
-                            .collect::<Vec<String>>()
-                            .join("\n");
-                        let mut messages =
-                            vec![json!({ "role": MessageRole::Assistant, "content": content })];
-                        for tool_result in tool_results {
-                            messages.push(json!({
+                        tool_results.into_iter().flat_map(|tool_result| vec![
+                            json!({
+                                "role": MessageRole::Assistant,
+                                "content": "",
+                                "tool_calls": vec![
+                                    json!({
+                                        "type": "function",
+                                        "function": {
+                                            "name": tool_result.call.name,
+                                            "arguments": tool_result.call.arguments.to_string(),
+                                        },
+                                    })
+                                ],
+                            }),
+                            json!({
                                 "role": "tool",
                                 "content": tool_result.output.to_string(),
                                 "name": tool_result.call.name,
-                            }));
-                        }
-                        messages
+                            }),
+                        ]).collect()
                     }
                 }
             })
             .collect();
         json!({
             "messages": messages,
         })
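This is the core of the fix: previously a tool result was flattened into a single assistant message whose content was the ✿FUNCTION✿/✿ARGS✿/✿RESULT✿ prompt text, which the DashScope API rejected as invalid tool_calls. Each tool result is now replayed as a proper assistant message carrying a tool_calls entry, followed by a matching tool-role message with the output. Roughly, for one hypothetical call (get_weather with {"city": "Paris"} returning "18°C"), the resulting pair of messages looks like this sketch:

    // Illustration only; "get_weather", its arguments and output are made up.
    use serde_json::{json, Value};

    fn example_tool_result_messages() -> Vec<Value> {
        vec![
            json!({
                "role": "assistant",
                "content": "",
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\":\"Paris\"}",
                    },
                }],
            }),
            json!({
                "role": "tool",
                "content": "18°C",
                "name": "get_weather",
            }),
        ]
    }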
@@ -291,6 +293,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
     let mut parameters = json!({});

+    if stream && (model.name() == "qwen-long" || model.supports_vision()) {
+        parameters["incremental_output"] = true.into();
+    }
+
     if let Some(v) = model.max_tokens_param() {
         parameters["max_tokens"] = v.into();
     }
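incremental_output is DashScope's flag for delta streaming: with it set, each stream event carries only the newly generated tokens rather than the full text so far. The commit enables it for qwen-long and vision models, which is why their handlers above now forward each chunk directly; the function-calling path keeps full-output streaming so it can scan the accumulated text for ✿FUNCTION✿ markers. A rough sketch of the resulting parameters object for a streaming qwen-long request (the max_tokens value is made up):

    // Sketch, not the client's actual builder: how `parameters` ends up.
    use serde_json::{json, Value};

    fn example_parameters() -> Value {
        let mut parameters = json!({});
        // New in this commit: ask DashScope for delta chunks instead of full text.
        parameters["incremental_output"] = true.into();
        // Hypothetical limit, just to show the existing max_tokens handling.
        parameters["max_tokens"] = 4096.into();
        parameters
    }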
