feat: proxy chat-completions api with tools support (#850)

pull/851/head
sigoden 1 month ago committed by GitHub
parent a56d5f2ddf
commit 69965466e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -108,9 +108,12 @@ pub async fn openai_chat_completions_streaming(
let handle = |message: SseMmessage| -> Result<bool> {
if message.data == "[DONE]" {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
arguments,
normalize_function_id(&function_id),
))?;
}
@ -128,9 +131,12 @@ pub async fn openai_chat_completions_streaming(
let index = index.unwrap_or_default();
if index != function_index {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
arguments,
normalize_function_id(&function_id),
))?;
}
@ -207,7 +213,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments,
"arguments": tool_result.call.arguments.to_string(),
},
})
}).collect();
@ -237,7 +243,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments,
"arguments": tool_result.call.arguments.to_string(),
},
}
]
@ -302,24 +308,24 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
let mut tool_calls = vec![];
if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
for call in calls {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
Some(ToolCall::new(
let arguments: Value = arguments.parse().with_context(|| {
format!(
"Tool call '{name}' is invalid: arguments must be in valid JSON format"
)
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
json!(arguments),
arguments,
Some(id.to_string()),
))
} else {
None
));
}
}
})
.collect()
};
if text.is_empty() && tool_calls.is_empty() {

@ -65,6 +65,10 @@ impl SseHandler {
self.abort.clone()
}
pub fn get_tool_calls(&self) -> &[ToolCall] {
&self.tool_calls
}
pub fn take(self) -> (String, Vec<ToolCall>) {
let Self {
buffer, tool_calls, ..

@ -275,8 +275,14 @@ impl Server {
top_p,
max_tokens,
stream,
tools,
} = req_body;
let messages =
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
let config = self.config.clone();
let default_model = config.model.clone();
@ -309,7 +315,7 @@ impl Server {
messages,
temperature,
top_p,
functions: None,
functions,
stream,
};
@ -351,11 +357,17 @@ impl Server {
let ret = client
.chat_completions_streaming_inner(http_client, handler, data)
.await;
if let Err(err) = ret {
let first = match ret {
Ok(()) => None,
Err(err) => Some(format!("{err:?}")),
};
if is_first.load(Ordering::SeqCst) {
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
let _ = tx.send(ResEvent::First(first));
is_first.store(false, Ordering::SeqCst)
}
let tool_calls = handler.get_tool_calls();
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls.to_vec()));
}
handler.done();
}
@ -378,23 +390,32 @@ impl Server {
bail!("{err}");
}
let shared: Arc<(String, String, i64)> = Arc::new((completion_id, model_name, created));
let shared: Arc<(String, String, i64, AtomicBool)> =
Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
let stream = UnboundedReceiverStream::new(rx);
let stream = stream.filter_map(move |res_event| {
let shared = shared.clone();
async move {
let (completion_id, model, created) = shared.as_ref();
let (completion_id, model, created, has_tool_calls) = shared.as_ref();
match res_event {
ResEvent::Text(text) => Some(Ok(create_frame(
ResEvent::Text(text) => {
Some(Ok(create_text_frame(completion_id, model, *created, &text)))
}
ResEvent::ToolCalls(tool_calls) => {
has_tool_calls.store(true, Ordering::SeqCst);
Some(Ok(create_tool_calls_frame(
completion_id,
model,
*created,
&text,
false,
))),
ResEvent::Done => {
Some(Ok(create_frame(completion_id, model, *created, "", true)))
&tool_calls,
)))
}
ResEvent::Done => Some(Ok(create_done_frame(
completion_id,
model,
*created,
has_tool_calls.load(Ordering::SeqCst),
))),
_ => None,
}
}
@ -488,12 +509,13 @@ struct SearchRagReqBody {
#[derive(Debug, Deserialize)]
struct ChatCompletionsReqBody {
model: String,
messages: Vec<Message>,
messages: Vec<Value>,
temperature: Option<f64>,
top_p: Option<f64>,
max_tokens: Option<isize>,
#[serde(default)]
stream: bool,
tools: Option<Vec<Value>>,
}
#[derive(Debug, Deserialize)]
@ -513,6 +535,7 @@ enum EmbeddingsReqBodyInput {
enum ResEvent {
First(Option<String>),
Text(String),
ToolCalls(Vec<ToolCall>),
Done,
}
@ -542,36 +565,94 @@ fn set_cors_header(res: &mut AppResponse) {
);
}
fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool) -> Frame<Bytes> {
let (delta, finish_reason) = if done {
(json!({}), "stop".into())
} else {
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
let delta = if content.is_empty() {
json!({ "role": "assistant", "content": content })
} else {
json!({ "content": content })
};
(delta, Value::Null)
};
let value = json!({
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
let choice = json!({
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
"finish_reason": null,
});
let value = build_chat_completion_chunk_json(id, model, created, &choice);
Frame::data(Bytes::from(format!("data: {value}\n\n")))
}
fn create_tool_calls_frame(
id: &str,
model: &str,
created: i64,
tool_calls: &[ToolCall],
) -> Frame<Bytes> {
let chunks = tool_calls
.iter()
.enumerate()
.flat_map(|(i, call)| {
let choice1 = json!({
"index": 0,
"delta": {
"role": "assistant",
"content": null,
"tool_calls": [
{
"index": i,
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": ""
}
}
]
},
],
"finish_reason": null
});
let output = if done {
format!("data: {value}\n\ndata: [DONE]\n\n")
} else {
format!("data: {value}\n\n")
};
Frame::data(Bytes::from(output))
let choice2 = json!({
"index": 0,
"delta": {
"tool_calls": [
{
"index": i,
"function": {
"arguments": call.arguments.to_string(),
}
}
]
},
"finish_reason": null
});
vec![
build_chat_completion_chunk_json(id, model, created, &choice1),
build_chat_completion_chunk_json(id, model, created, &choice2),
]
})
.map(|v| format!("data: {v}\n\n"))
.collect::<Vec<String>>()
.join("");
Frame::data(Bytes::from(chunks))
}
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" };
let choice = json!({
"index": 0,
"delta": {},
"finish_reason": finish_reason,
});
let value = build_chat_completion_chunk_json(id, model, created, &choice);
Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n")))
}
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
json!({
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [choice],
})
}
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
@ -579,13 +660,8 @@ fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsO
let input_tokens = output.input_tokens.unwrap_or_default();
let output_tokens = output.output_tokens.unwrap_or_default();
let total_tokens = input_tokens + output_tokens;
let res_body = json!({
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
{
let choice = if output.tool_calls.is_empty() {
json!({
"index": 0,
"message": {
"role": "assistant",
@ -593,8 +669,44 @@ fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsO
},
"logprobs": null,
"finish_reason": "stop",
})
} else {
let content = if output.text.is_empty() {
Value::Null
} else {
output.text.clone().into()
};
let tool_calls: Vec<_> = output
.tool_calls
.iter()
.map(|call| {
json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": call.arguments.to_string(),
}
})
})
.collect();
json!({
"index": 0,
"message": {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
},
],
"logprobs": null,
"finish_reason": "tool_calls",
})
};
let res_body = json!({
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [choice],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
@ -616,3 +728,124 @@ fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
.body(Full::new(Bytes::from(data.to_string())).boxed())
.unwrap()
}
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
let mut output = vec![];
let mut tool_results = None;
for (i, message) in message.into_iter().enumerate() {
let err = || anyhow!("Failed to parse '.messages[{i}]'");
let role = message["role"].as_str().ok_or_else(err)?;
let content = match message.get("content") {
Some(value) => {
if let Some(value) = value.as_str() {
MessageContent::Text(value.to_string())
} else if value.is_array() {
let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
MessageContent::Array(value)
} else if value.is_null() {
MessageContent::Text(String::new())
} else {
return Err(err());
}
}
None => MessageContent::Text(String::new()),
};
match role {
"system" | "user" => {
let role = match role {
"system" => MessageRole::System,
"user" => MessageRole::User,
_ => unreachable!(),
};
output.push(Message::new(role, content))
}
"assistant" => {
let role = MessageRole::Assistant;
match message["tool_calls"].as_array() {
Some(tool_calls) => {
if tool_results.is_some() {
return Err(err());
}
let mut list = vec![];
for tool_call in tool_calls {
if let (id, Some(name), Some(arguments)) = (
tool_call["id"].as_str().map(|v| v.to_string()),
tool_call["function"]["name"].as_str(),
tool_call["function"]["arguments"].as_str(),
) {
let arguments =
serde_json::from_str(arguments).map_err(|_| err())?;
list.push((id, name.to_string(), arguments));
} else {
return Err(err());
}
}
tool_results = Some((content.to_text(), list, vec![]));
}
None => output.push(Message::new(role, content)),
}
}
"tool" => match tool_results.take() {
Some((text, tool_calls, mut tool_values)) => {
let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
let content = content.to_text();
let value: Value = serde_json::from_str(&content)
.ok()
.unwrap_or_else(|| content.into());
tool_values.push((value, tool_call_id));
if tool_calls.len() == tool_values.len() {
let mut list = vec![];
for ((id, name, arguments), (value, tool_call_id)) in
tool_calls.into_iter().zip(tool_values.into_iter())
{
if id != tool_call_id {
return Err(err());
}
list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
}
output.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults((list, text)),
));
tool_results = None;
} else {
tool_results = Some((text, tool_calls, tool_values));
}
}
None => return Err(err()),
},
_ => {
return Err(err());
}
}
}
if tool_results.is_some() {
bail!("Invalid messages");
}
Ok(output)
}
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
let tools = match tools {
Some(v) => v,
None => return Ok(None),
};
let mut functions = vec![];
for (i, tool) in tools.into_iter().enumerate() {
if let (Some("function"), Some(function)) = (
tool["type"].as_str(),
tool["function"]
.as_object()
.and_then(|v| serde_json::from_value(json!(v)).ok()),
) {
functions.push(function);
} else {
bail!("Failed to parse '.tools[{i}]'")
}
}
Ok(Some(functions))
}

Loading…
Cancel
Save