From 250e0eb7fee0d180b08b7d38d10c530e9a6a120d Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 22 Jun 2024 14:00:25 +0800 Subject: [PATCH] feat: ernie support function calling (#631) --- README.md | 2 +- models.yaml | 2 ++ src/client/ernie.rs | 70 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 9c5cb04..74fcfcf 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ AIChat is an all-in-one AI CLI tool featuring chat REPL, RAG, function calling, - Bedrock: Llama-3/Claude-3.5/Claude-3/Mistral (paid, vision) - Cloudflare (free, vision, embedding) - Replicate (paid) -- Ernie (paid) +- Ernie (paid, embedding, rerank, function-calling) - Qianwen: Qwen (paid, vision, embedding, function-calling) - Moonshot (paid, function-calling) - Deepseek (paid) diff --git a/models.yaml b/models.yaml index 867d0c0..10c6d50 100644 --- a/models.yaml +++ b/models.yaml @@ -501,10 +501,12 @@ max_input_tokens: 8192 input_price: 16.8 output_price: 16.8 + supports_function_calling: true - name: ernie-3.5-8k-0613 max_input_tokens: 8192 input_price: 1.68 output_price: 1.68 + supports_function_calling: true - name: ernie-speed-128k max_input_tokens: 128000 input_price: 0 diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 3b6f562..21362c1 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,7 +1,7 @@ use super::access_token::*; use super::*; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -93,7 +93,7 @@ impl ErnieClient { &self.model.name(), ); - debug!("Ernie Re Rerank: {url} {body}"); + debug!("Ernie Rerank Request: {url} {body}"); let builder = client.post(url).json(&body); @@ -179,7 +179,17 @@ async fn chat_completions_streaming( let handle = |message: SseMmessage| -> Result { let data: Value = serde_json::from_str(&message.data)?; debug!("stream-data: {data}"); - if let Some(text) = data["result"].as_str() { + if let Some(function) = data["function_call"].as_object() { + if let (Some(name), Some(arguments)) = ( + function.get("name").and_then(|v| v.as_str()), + function.get("arguments").and_then(|v| v.as_str()), + ) { + let arguments: Value = arguments.parse().with_context(|| { + format!("Tool call '{name}' is invalid: arguments must be in valid JSON format") + })?; + handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?; + } + } else if let Some(text) = data["result"].as_str() { handler.text(text)?; } Ok(false) @@ -224,12 +234,37 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu mut messages, temperature, top_p, - functions: _, + functions, stream, } = data; patch_system_message(&mut messages); + let messages: Vec = messages + .into_iter() + .flat_map(|message| { + let Message { role, content } = message; + match content { + MessageContent::ToolResults((tool_results, _)) => { + let mut list = vec![]; + for tool_result in tool_results { + list.push(json!({ + "role": "assistant", + "content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments) + })); + list.push(json!({ + "role": "user", + "content": tool_result.output.to_string(), + })) + + } + list + } + _ => vec![json!({ "role": role, "content": content })], + } + }) + .collect(); + let mut body = json!({ "messages": messages, }); @@ -248,16 +283,35 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu body["stream"] = true.into(); } + if let Some(functions) = functions { + body["functions"] = json!(functions); + } + body } fn extract_chat_completions_text(data: &Value) -> Result { - let text = data["result"] - .as_str() - .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; + let text = data["result"].as_str().unwrap_or_default(); + + let mut tool_calls = vec![]; + if let Some(call) = data["function_call"].as_object() { + if let (Some(name), Some(arguments)) = ( + call.get("name").and_then(|v| v.as_str()), + call.get("arguments").and_then(|v| v.as_str()), + ) { + let arguments: Value = arguments.parse().with_context(|| { + format!("Tool call '{name}' is invalid: arguments must be in valid JSON format") + })?; + tool_calls.push(ToolCall::new(name.to_string(), arguments, None)); + } + } + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } let output = ChatCompletionsOutput { text: text.to_string(), - tool_calls: vec![], + tool_calls, id: data["id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["prompt_tokens"].as_u64(), output_tokens: data["usage"]["completion_tokens"].as_u64(),