feat: qianwen support function calling (#616)

pull/617/head
sigoden 2 weeks ago committed by GitHub
parent 98ac7e2b57
commit 1fb06ecdc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -148,17 +148,14 @@
max_input_tokens: 64000
input_price: 2
output_price: 6
supports_function_calling: true
- name: mistral-small-latest
max_input_tokens: 32000
input_price: 1
output_price: 3
supports_function_calling: true
- name: mistral-large-latest
max_input_tokens: 32000
input_price: 4
output_price: 12
supports_function_calling: true
- name: codestral-latest
max_input_tokens: 32000
input_price: 1
@ -515,16 +512,19 @@
max_output_tokens: 1500
input_price: 0.28
output_price: 0.84
supports_function_calling: true
- name: qwen-plus
max_input_tokens: 30000
max_output_tokens: 2000
input_price: 0.56
output_price: 1.68
supports_function_calling: true
- name: qwen-max
max_input_tokens: 6000
max_output_tokens: 2000
input_price: 5.6
output_price: 16.8
supports_function_calling: true
- name: qwen-max-longcontext
input_price: 5.6
output_price: 16.8

@ -149,22 +149,43 @@ async fn chat_completions_streaming(
model: &Model,
) -> Result<()> {
let model_name = model.name();
let mut prev_text = String::new();
let handle = |message: SseMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
maybe_catch_error(&data)?;
debug!("stream-data: {data}");
if model_name == "qwen-long" {
if let Some(text) = data["output"]["choices"][0]["message"]["content"].as_str() {
handler.text(text)?;
let delta_text = &text[prev_text.len()..];
prev_text = text.to_string();
handler.text(delta_text)?;
}
} else if model.supports_vision() {
if let Some(text) =
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
{
handler.text(text)?;
let delta_text = &text[prev_text.len()..];
prev_text = text.to_string();
handler.text(delta_text)?;
}
} else if let Some(text) = data["output"]["text"].as_str() {
handler.text(text)?;
if let Some(pos) = text.rfind("✿FUNCTION") {
if pos > prev_text.len() {
let delta_text = &text[prev_text.len()..pos];
handler.text(delta_text)?;
}
prev_text = text.to_string();
if let Some((name, arguments)) = parse_tool_call(&text[pos..]) {
let arguments: Value = arguments
.parse()
.with_context(|| format!("Invalid function call {name} {arguments}"))?;
handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?;
}
} else {
let delta_text = &text[prev_text.len()..];
prev_text = text.to_string();
handler.text(delta_text)?;
}
}
Ok(false)
};
@ -177,12 +198,11 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
messages,
temperature,
top_p,
functions: _,
stream,
functions,
stream: _,
} = data;
let mut has_upload = false;
let mut is_tool_call = false;
let input = if model.supports_vision() {
let messages: Vec<Value> = messages
.into_iter()
@ -205,7 +225,6 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
})
.collect(),
MessageContent::ToolResults(_) => {
is_tool_call = true;
vec![]
}
};
@ -217,18 +236,60 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
"messages": messages,
})
} else {
let messages: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
let role = message.role;
match message.content {
MessageContent::Text(text) => vec![json!({ "role": role, "content": text })],
MessageContent::Array(list) => {
let parts: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if url.starts_with("oss:") {
has_upload = true;
}
json!({"image": url})
}
})
.collect();
vec![json!({ "role": role, "content": parts })]
}
MessageContent::ToolResults((tool_results, _)) => {
let content = tool_results
.iter()
.map(|tool_result| {
format!(
"✿FUNCTION✿: {}\n✿ARGS✿: {}\n✿RESULT✿",
tool_result.call.name, tool_result.call.arguments
)
})
.collect::<Vec<String>>()
.join("\n");
let mut messages =
vec![json!({ "role": MessageRole::Assistant, "content": content })];
for tool_result in tool_results {
messages.push(json!({
"role": "tool",
"content": tool_result.output.to_string(),
"name": tool_result.call.name,
}));
}
messages
}
}
})
.collect();
json!({
"messages": messages,
})
};
if is_tool_call {
bail!("The client does not support function calling",);
}
let mut parameters = json!({});
if stream {
parameters["incremental_output"] = true.into();
}
if let Some(v) = model.max_tokens_param() {
parameters["max_tokens"] = v.into();
@ -240,6 +301,18 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
parameters["top_p"] = v.into();
}
if let Some(functions) = functions {
parameters["tools"] = functions
.iter()
.map(|v| {
json!({
"type": "function",
"function": v,
})
})
.collect();
}
let body = json!({
"model": &model.name(),
"input": input,
@ -280,6 +353,7 @@ struct EmbeddingsResBodyOutputEmbedding {
fn extract_chat_completions_text(data: &Value, model: &Model) -> Result<ChatCompletionsOutput> {
let err = || anyhow!("Invalid response data: {data}");
let mut tool_calls = vec![];
let text = if model.name() == "qwen-long" {
data["output"]["choices"][0]["message"]["content"]
.as_str()
@ -289,11 +363,21 @@ fn extract_chat_completions_text(data: &Value, model: &Model) -> Result<ChatComp
.as_str()
.ok_or_else(err)?
} else {
data["output"]["text"].as_str().ok_or_else(err)?
let text = data["output"]["text"].as_str().ok_or_else(err)?;
match parse_tool_call(text) {
Some((name, arguments)) => {
let arguments: Value = arguments
.parse()
.with_context(|| format!("Invalid function call {name} {arguments}"))?;
tool_calls.push(ToolCall::new(name.to_string(), arguments, None));
""
}
None => text,
}
};
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
tool_calls,
id: data["request_id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
@ -395,3 +479,14 @@ async fn upload(model: &str, api_key: &str, url: &str) -> Result<String> {
}
Ok(format!("oss://{key}"))
}
fn parse_tool_call(text: &str) -> Option<(&str, &str)> {
let function_symbol = "✿FUNCTION✿: ";
let result_symbol = "\n✿RESULT✿: ";
let args_symbol = "\n✿ARGS✿: ";
let start = text.find(function_symbol)? + function_symbol.len();
let text = &text[start..];
let end = text.find(result_symbol)?;
let text = &text[..end];
text.split_once(args_symbol)
}

Loading…
Cancel
Save