|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
|
use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
@ -65,10 +65,10 @@ impl OllamaClient {
|
|
|
|
|
|
|
|
|
|
let body = json!({
|
|
|
|
|
"model": self.model.name(),
|
|
|
|
|
"prompt": data.texts[0],
|
|
|
|
|
"input": data.texts,
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
let url = format!("{api_base}/api/embeddings");
|
|
|
|
|
let url = format!("{api_base}/api/embed");
|
|
|
|
|
|
|
|
|
|
debug!("Ollama Embeddings Request: {url} {body}");
|
|
|
|
|
|
|
|
|
@ -96,10 +96,8 @@ async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutp
|
|
|
|
|
catch_error(&data, status.as_u16())?;
|
|
|
|
|
}
|
|
|
|
|
debug!("non-stream-data: {data}");
|
|
|
|
|
let text = data["message"]["content"]
|
|
|
|
|
.as_str()
|
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
|
Ok(ChatCompletionsOutput::new(text))
|
|
|
|
|
|
|
|
|
|
extract_chat_completions(&data)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_streaming(
|
|
|
|
@ -142,13 +140,12 @@ async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
|
|
|
|
}
|
|
|
|
|
let res_body: EmbeddingsResBody =
|
|
|
|
|
serde_json::from_value(data).context("Invalid embeddings data")?;
|
|
|
|
|
let output = vec![res_body.embedding];
|
|
|
|
|
Ok(output)
|
|
|
|
|
Ok(res_body.embeddings)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
struct EmbeddingsResBody {
|
|
|
|
|
embedding: Vec<f32>,
|
|
|
|
|
embeddings: Vec<Vec<f32>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
@ -156,22 +153,21 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
|
|
|
|
|
messages,
|
|
|
|
|
temperature,
|
|
|
|
|
top_p,
|
|
|
|
|
functions: _,
|
|
|
|
|
functions,
|
|
|
|
|
stream,
|
|
|
|
|
} = data;
|
|
|
|
|
|
|
|
|
|
let mut is_tool_call = false;
|
|
|
|
|
let mut network_image_urls = vec![];
|
|
|
|
|
|
|
|
|
|
let messages: Vec<Value> = messages
|
|
|
|
|
.into_iter()
|
|
|
|
|
.map(|message| {
|
|
|
|
|
let role = message.role;
|
|
|
|
|
match message.content {
|
|
|
|
|
MessageContent::Text(text) => json!({
|
|
|
|
|
.flat_map(|message| {
|
|
|
|
|
let Message { role, content } = message;
|
|
|
|
|
match content {
|
|
|
|
|
MessageContent::Text(text) => vec![json!({
|
|
|
|
|
"role": role,
|
|
|
|
|
"content": text,
|
|
|
|
|
}),
|
|
|
|
|
})],
|
|
|
|
|
MessageContent::Array(list) => {
|
|
|
|
|
let mut content = vec![];
|
|
|
|
|
let mut images = vec![];
|
|
|
|
@ -195,20 +191,34 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let content = content.join("\n\n");
|
|
|
|
|
json!({ "role": role, "content": content, "images": images })
|
|
|
|
|
}
|
|
|
|
|
MessageContent::ToolResults(_) => {
|
|
|
|
|
is_tool_call = true;
|
|
|
|
|
json!({ "role": role })
|
|
|
|
|
vec![json!({ "role": role, "content": content, "images": images })]
|
|
|
|
|
}
|
|
|
|
|
MessageContent::ToolResults((tool_results, text)) => {
|
|
|
|
|
let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| {
|
|
|
|
|
json!({
|
|
|
|
|
"function": {
|
|
|
|
|
"name": tool_result.call.name,
|
|
|
|
|
"arguments": tool_result.call.arguments,
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
}).collect();
|
|
|
|
|
let mut messages = vec![
|
|
|
|
|
json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls })
|
|
|
|
|
];
|
|
|
|
|
for tool_result in tool_results {
|
|
|
|
|
messages.push(
|
|
|
|
|
json!({
|
|
|
|
|
"role": "tool",
|
|
|
|
|
"content": tool_result.output.to_string(),
|
|
|
|
|
})
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
messages
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
if is_tool_call {
|
|
|
|
|
bail!("The client does not support function calling",);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !network_image_urls.is_empty() {
|
|
|
|
|
bail!(
|
|
|
|
|
"The model does not support network images: {:?}",
|
|
|
|
@ -232,6 +242,50 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
|
|
|
|
|
if let Some(v) = top_p {
|
|
|
|
|
body["options"]["top_p"] = v.into();
|
|
|
|
|
}
|
|
|
|
|
if let Some(functions) = functions {
|
|
|
|
|
body["tools"] = functions
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|v| {
|
|
|
|
|
json!({
|
|
|
|
|
"type": "function",
|
|
|
|
|
"function": v,
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
let text = data["message"]["content"].as_str().unwrap_or_default();
|
|
|
|
|
|
|
|
|
|
let mut tool_calls = vec![];
|
|
|
|
|
if let Some(calls) = data["message"]["tool_calls"].as_array() {
|
|
|
|
|
tool_calls = calls
|
|
|
|
|
.iter()
|
|
|
|
|
.filter_map(|call| {
|
|
|
|
|
if let (Some(name), arguments) = (
|
|
|
|
|
call["function"]["name"].as_str(),
|
|
|
|
|
call["function"]["arguments"].clone(),
|
|
|
|
|
) {
|
|
|
|
|
Some(ToolCall::new(name.to_string(), arguments, None))
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.collect()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if text.is_empty() && tool_calls.is_empty() {
|
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
|
}
|
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
|
text: text.to_string(),
|
|
|
|
|
tool_calls,
|
|
|
|
|
id: None,
|
|
|
|
|
input_tokens: data["prompt_eval_count"].as_u64(),
|
|
|
|
|
output_tokens: data["eval_count"].as_u64(),
|
|
|
|
|
};
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|