From 1fb06ecdc4daeed618ec971170e82931464a8399 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 19 Jun 2024 06:17:29 +0800 Subject: [PATCH] feat: qianwen support function calling (#616) --- models.yaml | 6 +- src/client/qianwen.rs | 125 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 113 insertions(+), 18 deletions(-) diff --git a/models.yaml b/models.yaml index b193200..7562f91 100644 --- a/models.yaml +++ b/models.yaml @@ -148,17 +148,14 @@ max_input_tokens: 64000 input_price: 2 output_price: 6 - supports_function_calling: true - name: mistral-small-latest max_input_tokens: 32000 input_price: 1 output_price: 3 - supports_function_calling: true - name: mistral-large-latest max_input_tokens: 32000 input_price: 4 output_price: 12 - supports_function_calling: true - name: codestral-latest max_input_tokens: 32000 input_price: 1 @@ -515,16 +512,19 @@ max_output_tokens: 1500 input_price: 0.28 output_price: 0.84 + supports_function_calling: true - name: qwen-plus max_input_tokens: 30000 max_output_tokens: 2000 input_price: 0.56 output_price: 1.68 + supports_function_calling: true - name: qwen-max max_input_tokens: 6000 max_output_tokens: 2000 input_price: 5.6 output_price: 16.8 + supports_function_calling: true - name: qwen-max-longcontext input_price: 5.6 output_price: 16.8 diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 908cd1e..b0d0b58 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -149,22 +149,43 @@ async fn chat_completions_streaming( model: &Model, ) -> Result<()> { let model_name = model.name(); + let mut prev_text = String::new(); let handle = |message: SseMmessage| -> Result { let data: Value = serde_json::from_str(&message.data)?; maybe_catch_error(&data)?; debug!("stream-data: {data}"); if model_name == "qwen-long" { if let Some(text) = data["output"]["choices"][0]["message"]["content"].as_str() { - handler.text(text)?; + let delta_text = &text[prev_text.len()..]; + prev_text = text.to_string(); + handler.text(delta_text)?; } } else if model.supports_vision() { if let Some(text) = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() { - handler.text(text)?; + let delta_text = &text[prev_text.len()..]; + prev_text = text.to_string(); + handler.text(delta_text)?; } } else if let Some(text) = data["output"]["text"].as_str() { - handler.text(text)?; + if let Some(pos) = text.rfind("✿FUNCTION") { + if pos > prev_text.len() { + let delta_text = &text[prev_text.len()..pos]; + handler.text(delta_text)?; + } + prev_text = text.to_string(); + if let Some((name, arguments)) = parse_tool_call(&text[pos..]) { + let arguments: Value = arguments + .parse() + .with_context(|| format!("Invalid function call {name} {arguments}"))?; + handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?; + } + } else { + let delta_text = &text[prev_text.len()..]; + prev_text = text.to_string(); + handler.text(delta_text)?; + } } Ok(false) }; @@ -177,12 +198,11 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu messages, temperature, top_p, - functions: _, - stream, + functions, + stream: _, } = data; let mut has_upload = false; - let mut is_tool_call = false; let input = if model.supports_vision() { let messages: Vec = messages .into_iter() @@ -205,7 +225,6 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu }) .collect(), MessageContent::ToolResults(_) => { - is_tool_call = true; vec![] } }; @@ -217,18 +236,60 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu "messages": messages, }) } else { + let messages: Vec = messages + .into_iter() + .flat_map(|message| { + let role = message.role; + match message.content { + MessageContent::Text(text) => vec![json!({ "role": role, "content": text })], + MessageContent::Array(list) => { + let parts: Vec<_> = list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => json!({"text": text}), + MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + } => { + if url.starts_with("oss:") { + has_upload = true; + } + json!({"image": url}) + } + }) + .collect(); + vec![json!({ "role": role, "content": parts })] + } + MessageContent::ToolResults((tool_results, _)) => { + let content = tool_results + .iter() + .map(|tool_result| { + format!( + "✿FUNCTION✿: {}\n✿ARGS✿: {}\n✿RESULT✿", + tool_result.call.name, tool_result.call.arguments + ) + }) + .collect::>() + .join("\n"); + let mut messages = + vec![json!({ "role": MessageRole::Assistant, "content": content })]; + for tool_result in tool_results { + messages.push(json!({ + "role": "tool", + "content": tool_result.output.to_string(), + "name": tool_result.call.name, + })); + } + messages + } + } + }) + .collect(); json!({ "messages": messages, }) }; - if is_tool_call { - bail!("The client does not support function calling",); - } let mut parameters = json!({}); - if stream { - parameters["incremental_output"] = true.into(); - } if let Some(v) = model.max_tokens_param() { parameters["max_tokens"] = v.into(); @@ -240,6 +301,18 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu parameters["top_p"] = v.into(); } + if let Some(functions) = functions { + parameters["tools"] = functions + .iter() + .map(|v| { + json!({ + "type": "function", + "function": v, + }) + }) + .collect(); + } + let body = json!({ "model": &model.name(), "input": input, @@ -280,6 +353,7 @@ struct EmbeddingsResBodyOutputEmbedding { fn extract_chat_completions_text(data: &Value, model: &Model) -> Result { let err = || anyhow!("Invalid response data: {data}"); + let mut tool_calls = vec![]; let text = if model.name() == "qwen-long" { data["output"]["choices"][0]["message"]["content"] .as_str() @@ -289,11 +363,21 @@ fn extract_chat_completions_text(data: &Value, model: &Model) -> Result { + let arguments: Value = arguments + .parse() + .with_context(|| format!("Invalid function call {name} {arguments}"))?; + tool_calls.push(ToolCall::new(name.to_string(), arguments, None)); + "" + } + None => text, + } }; let output = ChatCompletionsOutput { text: text.to_string(), - tool_calls: vec![], + tool_calls, id: data["request_id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["input_tokens"].as_u64(), output_tokens: data["usage"]["output_tokens"].as_u64(), @@ -395,3 +479,14 @@ async fn upload(model: &str, api_key: &str, url: &str) -> Result { } Ok(format!("oss://{key}")) } + +fn parse_tool_call(text: &str) -> Option<(&str, &str)> { + let function_symbol = "✿FUNCTION✿: "; + let result_symbol = "\n✿RESULT✿: "; + let args_symbol = "\n✿ARGS✿: "; + let start = text.find(function_symbol)? + function_symbol.len(); + let text = &text[start..]; + let end = text.find(result_symbol)?; + let text = &text[..end]; + text.split_once(args_symbol) +}