|
|
|
@ -275,8 +275,14 @@ impl Server {
|
|
|
|
|
top_p,
|
|
|
|
|
max_tokens,
|
|
|
|
|
stream,
|
|
|
|
|
tools,
|
|
|
|
|
} = req_body;
|
|
|
|
|
|
|
|
|
|
let messages =
|
|
|
|
|
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
|
|
|
|
|
|
let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
|
|
|
|
|
|
let config = self.config.clone();
|
|
|
|
|
|
|
|
|
|
let default_model = config.model.clone();
|
|
|
|
@ -309,7 +315,7 @@ impl Server {
|
|
|
|
|
messages,
|
|
|
|
|
temperature,
|
|
|
|
|
top_p,
|
|
|
|
|
functions: None,
|
|
|
|
|
functions,
|
|
|
|
|
stream,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -351,11 +357,17 @@ impl Server {
|
|
|
|
|
let ret = client
|
|
|
|
|
.chat_completions_streaming_inner(http_client, handler, data)
|
|
|
|
|
.await;
|
|
|
|
|
if let Err(err) = ret {
|
|
|
|
|
if is_first.load(Ordering::SeqCst) {
|
|
|
|
|
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
|
|
|
|
|
is_first.store(false, Ordering::SeqCst)
|
|
|
|
|
}
|
|
|
|
|
let first = match ret {
|
|
|
|
|
Ok(()) => None,
|
|
|
|
|
Err(err) => Some(format!("{err:?}")),
|
|
|
|
|
};
|
|
|
|
|
if is_first.load(Ordering::SeqCst) {
|
|
|
|
|
let _ = tx.send(ResEvent::First(first));
|
|
|
|
|
is_first.store(false, Ordering::SeqCst)
|
|
|
|
|
}
|
|
|
|
|
let tool_calls = handler.get_tool_calls();
|
|
|
|
|
if !tool_calls.is_empty() {
|
|
|
|
|
let _ = tx.send(ResEvent::ToolCalls(tool_calls.to_vec()));
|
|
|
|
|
}
|
|
|
|
|
handler.done();
|
|
|
|
|
}
|
|
|
|
@ -378,23 +390,32 @@ impl Server {
|
|
|
|
|
bail!("{err}");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let shared: Arc<(String, String, i64)> = Arc::new((completion_id, model_name, created));
|
|
|
|
|
let shared: Arc<(String, String, i64, AtomicBool)> =
|
|
|
|
|
Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
|
|
|
|
|
let stream = UnboundedReceiverStream::new(rx);
|
|
|
|
|
let stream = stream.filter_map(move |res_event| {
|
|
|
|
|
let shared = shared.clone();
|
|
|
|
|
async move {
|
|
|
|
|
let (completion_id, model, created) = shared.as_ref();
|
|
|
|
|
let (completion_id, model, created, has_tool_calls) = shared.as_ref();
|
|
|
|
|
match res_event {
|
|
|
|
|
ResEvent::Text(text) => Some(Ok(create_frame(
|
|
|
|
|
ResEvent::Text(text) => {
|
|
|
|
|
Some(Ok(create_text_frame(completion_id, model, *created, &text)))
|
|
|
|
|
}
|
|
|
|
|
ResEvent::ToolCalls(tool_calls) => {
|
|
|
|
|
has_tool_calls.store(true, Ordering::SeqCst);
|
|
|
|
|
Some(Ok(create_tool_calls_frame(
|
|
|
|
|
completion_id,
|
|
|
|
|
model,
|
|
|
|
|
*created,
|
|
|
|
|
&tool_calls,
|
|
|
|
|
)))
|
|
|
|
|
}
|
|
|
|
|
ResEvent::Done => Some(Ok(create_done_frame(
|
|
|
|
|
completion_id,
|
|
|
|
|
model,
|
|
|
|
|
*created,
|
|
|
|
|
&text,
|
|
|
|
|
false,
|
|
|
|
|
has_tool_calls.load(Ordering::SeqCst),
|
|
|
|
|
))),
|
|
|
|
|
ResEvent::Done => {
|
|
|
|
|
Some(Ok(create_frame(completion_id, model, *created, "", true)))
|
|
|
|
|
}
|
|
|
|
|
_ => None,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -488,12 +509,13 @@ struct SearchRagReqBody {
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
|
struct ChatCompletionsReqBody {
|
|
|
|
|
model: String,
|
|
|
|
|
messages: Vec<Message>,
|
|
|
|
|
messages: Vec<Value>,
|
|
|
|
|
temperature: Option<f64>,
|
|
|
|
|
top_p: Option<f64>,
|
|
|
|
|
max_tokens: Option<isize>,
|
|
|
|
|
#[serde(default)]
|
|
|
|
|
stream: bool,
|
|
|
|
|
tools: Option<Vec<Value>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
@ -513,6 +535,7 @@ enum EmbeddingsReqBodyInput {
|
|
|
|
|
enum ResEvent {
|
|
|
|
|
First(Option<String>),
|
|
|
|
|
Text(String),
|
|
|
|
|
ToolCalls(Vec<ToolCall>),
|
|
|
|
|
Done,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -542,36 +565,94 @@ fn set_cors_header(res: &mut AppResponse) {
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool) -> Frame<Bytes> {
|
|
|
|
|
let (delta, finish_reason) = if done {
|
|
|
|
|
(json!({}), "stop".into())
|
|
|
|
|
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
|
|
|
|
|
let delta = if content.is_empty() {
|
|
|
|
|
json!({ "role": "assistant", "content": content })
|
|
|
|
|
} else {
|
|
|
|
|
let delta = if content.is_empty() {
|
|
|
|
|
json!({ "role": "assistant", "content": content })
|
|
|
|
|
} else {
|
|
|
|
|
json!({ "content": content })
|
|
|
|
|
};
|
|
|
|
|
(delta, Value::Null)
|
|
|
|
|
json!({ "content": content })
|
|
|
|
|
};
|
|
|
|
|
let value = json!({
|
|
|
|
|
let choice = json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"delta": delta,
|
|
|
|
|
"finish_reason": null,
|
|
|
|
|
});
|
|
|
|
|
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
|
|
|
Frame::data(Bytes::from(format!("data: {value}\n\n")))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn create_tool_calls_frame(
|
|
|
|
|
id: &str,
|
|
|
|
|
model: &str,
|
|
|
|
|
created: i64,
|
|
|
|
|
tool_calls: &[ToolCall],
|
|
|
|
|
) -> Frame<Bytes> {
|
|
|
|
|
let chunks = tool_calls
|
|
|
|
|
.iter()
|
|
|
|
|
.enumerate()
|
|
|
|
|
.flat_map(|(i, call)| {
|
|
|
|
|
let choice1 = json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"delta": {
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": null,
|
|
|
|
|
"tool_calls": [
|
|
|
|
|
{
|
|
|
|
|
"index": i,
|
|
|
|
|
"id": call.id,
|
|
|
|
|
"type": "function",
|
|
|
|
|
"function": {
|
|
|
|
|
"name": call.name,
|
|
|
|
|
"arguments": ""
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"finish_reason": null
|
|
|
|
|
});
|
|
|
|
|
let choice2 = json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"delta": {
|
|
|
|
|
"tool_calls": [
|
|
|
|
|
{
|
|
|
|
|
"index": i,
|
|
|
|
|
"function": {
|
|
|
|
|
"arguments": call.arguments.to_string(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"finish_reason": null
|
|
|
|
|
});
|
|
|
|
|
vec![
|
|
|
|
|
build_chat_completion_chunk_json(id, model, created, &choice1),
|
|
|
|
|
build_chat_completion_chunk_json(id, model, created, &choice2),
|
|
|
|
|
]
|
|
|
|
|
})
|
|
|
|
|
.map(|v| format!("data: {v}\n\n"))
|
|
|
|
|
.collect::<Vec<String>>()
|
|
|
|
|
.join("");
|
|
|
|
|
Frame::data(Bytes::from(chunks))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
|
|
|
|
|
let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" };
|
|
|
|
|
let choice = json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"delta": {},
|
|
|
|
|
"finish_reason": finish_reason,
|
|
|
|
|
});
|
|
|
|
|
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
|
|
|
Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n")))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
|
|
|
|
|
json!({
|
|
|
|
|
"id": id,
|
|
|
|
|
"object": "chat.completion.chunk",
|
|
|
|
|
"created": created,
|
|
|
|
|
"model": model,
|
|
|
|
|
"choices": [
|
|
|
|
|
{
|
|
|
|
|
"index": 0,
|
|
|
|
|
"delta": delta,
|
|
|
|
|
"finish_reason": finish_reason,
|
|
|
|
|
},
|
|
|
|
|
],
|
|
|
|
|
});
|
|
|
|
|
let output = if done {
|
|
|
|
|
format!("data: {value}\n\ndata: [DONE]\n\n")
|
|
|
|
|
} else {
|
|
|
|
|
format!("data: {value}\n\n")
|
|
|
|
|
};
|
|
|
|
|
Frame::data(Bytes::from(output))
|
|
|
|
|
"choices": [choice],
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
|
|
|
|
@ -579,22 +660,53 @@ fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsO
|
|
|
|
|
let input_tokens = output.input_tokens.unwrap_or_default();
|
|
|
|
|
let output_tokens = output.output_tokens.unwrap_or_default();
|
|
|
|
|
let total_tokens = input_tokens + output_tokens;
|
|
|
|
|
let choice = if output.tool_calls.is_empty() {
|
|
|
|
|
json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"message": {
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": output.text,
|
|
|
|
|
},
|
|
|
|
|
"logprobs": null,
|
|
|
|
|
"finish_reason": "stop",
|
|
|
|
|
})
|
|
|
|
|
} else {
|
|
|
|
|
let content = if output.text.is_empty() {
|
|
|
|
|
Value::Null
|
|
|
|
|
} else {
|
|
|
|
|
output.text.clone().into()
|
|
|
|
|
};
|
|
|
|
|
let tool_calls: Vec<_> = output
|
|
|
|
|
.tool_calls
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|call| {
|
|
|
|
|
json!({
|
|
|
|
|
"id": call.id,
|
|
|
|
|
"type": "function",
|
|
|
|
|
"function": {
|
|
|
|
|
"name": call.name,
|
|
|
|
|
"arguments": call.arguments.to_string(),
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
json!({
|
|
|
|
|
"index": 0,
|
|
|
|
|
"message": {
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": content,
|
|
|
|
|
"tool_calls": tool_calls,
|
|
|
|
|
},
|
|
|
|
|
"logprobs": null,
|
|
|
|
|
"finish_reason": "tool_calls",
|
|
|
|
|
})
|
|
|
|
|
};
|
|
|
|
|
let res_body = json!({
|
|
|
|
|
"id": id,
|
|
|
|
|
"object": "chat.completion",
|
|
|
|
|
"created": created,
|
|
|
|
|
"model": model,
|
|
|
|
|
"choices": [
|
|
|
|
|
{
|
|
|
|
|
"index": 0,
|
|
|
|
|
"message": {
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": output.text,
|
|
|
|
|
},
|
|
|
|
|
"logprobs": null,
|
|
|
|
|
"finish_reason": "stop",
|
|
|
|
|
},
|
|
|
|
|
],
|
|
|
|
|
"choices": [choice],
|
|
|
|
|
"usage": {
|
|
|
|
|
"prompt_tokens": input_tokens,
|
|
|
|
|
"completion_tokens": output_tokens,
|
|
|
|
@ -616,3 +728,124 @@ fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
|
|
|
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())
|
|
|
|
|
.unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
|
|
|
|
|
let mut output = vec![];
|
|
|
|
|
let mut tool_results = None;
|
|
|
|
|
for (i, message) in message.into_iter().enumerate() {
|
|
|
|
|
let err = || anyhow!("Failed to parse '.messages[{i}]'");
|
|
|
|
|
let role = message["role"].as_str().ok_or_else(err)?;
|
|
|
|
|
let content = match message.get("content") {
|
|
|
|
|
Some(value) => {
|
|
|
|
|
if let Some(value) = value.as_str() {
|
|
|
|
|
MessageContent::Text(value.to_string())
|
|
|
|
|
} else if value.is_array() {
|
|
|
|
|
let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
|
|
|
|
|
MessageContent::Array(value)
|
|
|
|
|
} else if value.is_null() {
|
|
|
|
|
MessageContent::Text(String::new())
|
|
|
|
|
} else {
|
|
|
|
|
return Err(err());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
None => MessageContent::Text(String::new()),
|
|
|
|
|
};
|
|
|
|
|
match role {
|
|
|
|
|
"system" | "user" => {
|
|
|
|
|
let role = match role {
|
|
|
|
|
"system" => MessageRole::System,
|
|
|
|
|
"user" => MessageRole::User,
|
|
|
|
|
_ => unreachable!(),
|
|
|
|
|
};
|
|
|
|
|
output.push(Message::new(role, content))
|
|
|
|
|
}
|
|
|
|
|
"assistant" => {
|
|
|
|
|
let role = MessageRole::Assistant;
|
|
|
|
|
match message["tool_calls"].as_array() {
|
|
|
|
|
Some(tool_calls) => {
|
|
|
|
|
if tool_results.is_some() {
|
|
|
|
|
return Err(err());
|
|
|
|
|
}
|
|
|
|
|
let mut list = vec![];
|
|
|
|
|
for tool_call in tool_calls {
|
|
|
|
|
if let (id, Some(name), Some(arguments)) = (
|
|
|
|
|
tool_call["id"].as_str().map(|v| v.to_string()),
|
|
|
|
|
tool_call["function"]["name"].as_str(),
|
|
|
|
|
tool_call["function"]["arguments"].as_str(),
|
|
|
|
|
) {
|
|
|
|
|
let arguments =
|
|
|
|
|
serde_json::from_str(arguments).map_err(|_| err())?;
|
|
|
|
|
list.push((id, name.to_string(), arguments));
|
|
|
|
|
} else {
|
|
|
|
|
return Err(err());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tool_results = Some((content.to_text(), list, vec![]));
|
|
|
|
|
}
|
|
|
|
|
None => output.push(Message::new(role, content)),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
"tool" => match tool_results.take() {
|
|
|
|
|
Some((text, tool_calls, mut tool_values)) => {
|
|
|
|
|
let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
|
|
|
|
|
let content = content.to_text();
|
|
|
|
|
let value: Value = serde_json::from_str(&content)
|
|
|
|
|
.ok()
|
|
|
|
|
.unwrap_or_else(|| content.into());
|
|
|
|
|
|
|
|
|
|
tool_values.push((value, tool_call_id));
|
|
|
|
|
|
|
|
|
|
if tool_calls.len() == tool_values.len() {
|
|
|
|
|
let mut list = vec![];
|
|
|
|
|
for ((id, name, arguments), (value, tool_call_id)) in
|
|
|
|
|
tool_calls.into_iter().zip(tool_values.into_iter())
|
|
|
|
|
{
|
|
|
|
|
if id != tool_call_id {
|
|
|
|
|
return Err(err());
|
|
|
|
|
}
|
|
|
|
|
list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
|
|
|
|
|
}
|
|
|
|
|
output.push(Message::new(
|
|
|
|
|
MessageRole::Assistant,
|
|
|
|
|
MessageContent::ToolResults((list, text)),
|
|
|
|
|
));
|
|
|
|
|
tool_results = None;
|
|
|
|
|
} else {
|
|
|
|
|
tool_results = Some((text, tool_calls, tool_values));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
None => return Err(err()),
|
|
|
|
|
},
|
|
|
|
|
_ => {
|
|
|
|
|
return Err(err());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if tool_results.is_some() {
|
|
|
|
|
bail!("Invalid messages");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
|
|
|
|
|
let tools = match tools {
|
|
|
|
|
Some(v) => v,
|
|
|
|
|
None => return Ok(None),
|
|
|
|
|
};
|
|
|
|
|
let mut functions = vec![];
|
|
|
|
|
for (i, tool) in tools.into_iter().enumerate() {
|
|
|
|
|
if let (Some("function"), Some(function)) = (
|
|
|
|
|
tool["type"].as_str(),
|
|
|
|
|
tool["function"]
|
|
|
|
|
.as_object()
|
|
|
|
|
.and_then(|v| serde_json::from_value(json!(v)).ok()),
|
|
|
|
|
) {
|
|
|
|
|
functions.push(function);
|
|
|
|
|
} else {
|
|
|
|
|
bail!("Failed to parse '.tools[{i}]'")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(Some(functions))
|
|
|
|
|
}
|
|
|
|
|