feat: ernie support function calling (#631)

pull/632/head
sigoden 2 weeks ago committed by GitHub
parent 1fd5c58cff
commit 250e0eb7fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -41,7 +41,7 @@ AIChat is an all-in-one AI CLI tool featuring chat REPL, RAG, function calling,
- Bedrock: Llama-3/Claude-3.5/Claude-3/Mistral (paid, vision)
- Cloudflare (free, vision, embedding)
- Replicate (paid)
- Ernie (paid)
- Ernie (paid, embedding, rerank, function-calling)
- Qianwen: Qwen (paid, vision, embedding, function-calling)
- Moonshot (paid, function-calling)
- Deepseek (paid)

@ -501,10 +501,12 @@
max_input_tokens: 8192
input_price: 16.8
output_price: 16.8
supports_function_calling: true
- name: ernie-3.5-8k-0613
max_input_tokens: 8192
input_price: 1.68
output_price: 1.68
supports_function_calling: true
- name: ernie-speed-128k
max_input_tokens: 128000
input_price: 0

@ -1,7 +1,7 @@
use super::access_token::*;
use super::*;
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -93,7 +93,7 @@ impl ErnieClient {
&self.model.name(),
);
debug!("Ernie Re Rerank: {url} {body}");
debug!("Ernie Rerank Request: {url} {body}");
let builder = client.post(url).json(&body);
@ -179,7 +179,17 @@ async fn chat_completions_streaming(
let handle = |message: SseMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["result"].as_str() {
if let Some(function) = data["function_call"].as_object() {
if let (Some(name), Some(arguments)) = (
function.get("name").and_then(|v| v.as_str()),
function.get("arguments").and_then(|v| v.as_str()),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?;
}
} else if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
Ok(false)
@ -224,12 +234,37 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
mut messages,
temperature,
top_p,
functions: _,
functions,
stream,
} = data;
patch_system_message(&mut messages);
let messages: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults((tool_results, _)) => {
let mut list = vec![];
for tool_result in tool_results {
list.push(json!({
"role": "assistant",
"content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments)
}));
list.push(json!({
"role": "user",
"content": tool_result.output.to_string(),
}))
}
list
}
_ => vec![json!({ "role": role, "content": content })],
}
})
.collect();
let mut body = json!({
"messages": messages,
});
@ -248,16 +283,35 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["functions"] = json!(functions);
}
body
}
fn extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["result"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let text = data["result"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(call) = data["function_call"].as_object() {
if let (Some(name), Some(arguments)) = (
call.get("name").and_then(|v| v.as_str()),
call.get("arguments").and_then(|v| v.as_str()),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
})?;
tool_calls.push(ToolCall::new(name.to_string(), arguments, None));
}
}
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),

Loading…
Cancel
Save