|
|
@ -1,7 +1,7 @@
|
|
|
|
use super::access_token::*;
|
|
|
|
use super::access_token::*;
|
|
|
|
use super::*;
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
use serde::Deserialize;
|
|
|
|
use serde::Deserialize;
|
|
|
@ -93,7 +93,7 @@ impl ErnieClient {
|
|
|
|
&self.model.name(),
|
|
|
|
&self.model.name(),
|
|
|
|
);
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
debug!("Ernie Re Rerank: {url} {body}");
|
|
|
|
debug!("Ernie Rerank Request: {url} {body}");
|
|
|
|
|
|
|
|
|
|
|
|
let builder = client.post(url).json(&body);
|
|
|
|
let builder = client.post(url).json(&body);
|
|
|
|
|
|
|
|
|
|
|
@ -179,7 +179,17 @@ async fn chat_completions_streaming(
|
|
|
|
let handle = |message: SseMmessage| -> Result<bool> {
|
|
|
|
let handle = |message: SseMmessage| -> Result<bool> {
|
|
|
|
let data: Value = serde_json::from_str(&message.data)?;
|
|
|
|
let data: Value = serde_json::from_str(&message.data)?;
|
|
|
|
debug!("stream-data: {data}");
|
|
|
|
debug!("stream-data: {data}");
|
|
|
|
if let Some(text) = data["result"].as_str() {
|
|
|
|
if let Some(function) = data["function_call"].as_object() {
|
|
|
|
|
|
|
|
if let (Some(name), Some(arguments)) = (
|
|
|
|
|
|
|
|
function.get("name").and_then(|v| v.as_str()),
|
|
|
|
|
|
|
|
function.get("arguments").and_then(|v| v.as_str()),
|
|
|
|
|
|
|
|
) {
|
|
|
|
|
|
|
|
let arguments: Value = arguments.parse().with_context(|| {
|
|
|
|
|
|
|
|
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else if let Some(text) = data["result"].as_str() {
|
|
|
|
handler.text(text)?;
|
|
|
|
handler.text(text)?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(false)
|
|
|
|
Ok(false)
|
|
|
@ -224,12 +234,37 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
|
|
|
|
mut messages,
|
|
|
|
mut messages,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_p,
|
|
|
|
top_p,
|
|
|
|
functions: _,
|
|
|
|
functions,
|
|
|
|
stream,
|
|
|
|
stream,
|
|
|
|
} = data;
|
|
|
|
} = data;
|
|
|
|
|
|
|
|
|
|
|
|
patch_system_message(&mut messages);
|
|
|
|
patch_system_message(&mut messages);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let messages: Vec<Value> = messages
|
|
|
|
|
|
|
|
.into_iter()
|
|
|
|
|
|
|
|
.flat_map(|message| {
|
|
|
|
|
|
|
|
let Message { role, content } = message;
|
|
|
|
|
|
|
|
match content {
|
|
|
|
|
|
|
|
MessageContent::ToolResults((tool_results, _)) => {
|
|
|
|
|
|
|
|
let mut list = vec![];
|
|
|
|
|
|
|
|
for tool_result in tool_results {
|
|
|
|
|
|
|
|
list.push(json!({
|
|
|
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
|
|
|
"content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments)
|
|
|
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
list.push(json!({
|
|
|
|
|
|
|
|
"role": "user",
|
|
|
|
|
|
|
|
"content": tool_result.output.to_string(),
|
|
|
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
list
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
_ => vec![json!({ "role": role, "content": content })],
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
|
|
let mut body = json!({
|
|
|
|
let mut body = json!({
|
|
|
|
"messages": messages,
|
|
|
|
"messages": messages,
|
|
|
|
});
|
|
|
|
});
|
|
|
@ -248,16 +283,35 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
|
|
|
|
body["stream"] = true.into();
|
|
|
|
body["stream"] = true.into();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(functions) = functions {
|
|
|
|
|
|
|
|
body["functions"] = json!(functions);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
body
|
|
|
|
body
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
fn extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
let text = data["result"]
|
|
|
|
let text = data["result"].as_str().unwrap_or_default();
|
|
|
|
.as_str()
|
|
|
|
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
let mut tool_calls = vec![];
|
|
|
|
|
|
|
|
if let Some(call) = data["function_call"].as_object() {
|
|
|
|
|
|
|
|
if let (Some(name), Some(arguments)) = (
|
|
|
|
|
|
|
|
call.get("name").and_then(|v| v.as_str()),
|
|
|
|
|
|
|
|
call.get("arguments").and_then(|v| v.as_str()),
|
|
|
|
|
|
|
|
) {
|
|
|
|
|
|
|
|
let arguments: Value = arguments.parse().with_context(|| {
|
|
|
|
|
|
|
|
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
tool_calls.push(ToolCall::new(name.to_string(), arguments, None));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if text.is_empty() && tool_calls.is_empty() {
|
|
|
|
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
|
|
|
|
}
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
text: text.to_string(),
|
|
|
|
text: text.to_string(),
|
|
|
|
tool_calls: vec![],
|
|
|
|
tool_calls,
|
|
|
|
id: data["id"].as_str().map(|v| v.to_string()),
|
|
|
|
id: data["id"].as_str().map(|v| v.to_string()),
|
|
|
|
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
|
|
|
|
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
|
|
|
|
output_tokens: data["usage"]["completion_tokens"].as_u64(),
|
|
|
|
output_tokens: data["usage"]["completion_tokens"].as_u64(),
|
|
|
|