feat: ollama support tools and new embeddings api (#748)

pull/749/head
sigoden 1 month ago committed by GitHub
parent a6f0196017
commit 3f7ce25709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -247,7 +247,7 @@ models-cohere() {
}
# @cmd Chat with ollama api
# @option -m --model=codegemma $OLLAMA_MODEL
# @option -m --model=llama3.1:latest $OLLAMA_MODEL
# @flag -S --no-stream
# @arg text~
chat-ollama() {

@ -1,6 +1,6 @@
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{bail, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -65,10 +65,10 @@ impl OllamaClient {
let body = json!({
"model": self.model.name(),
"prompt": data.texts[0],
"input": data.texts,
});
let url = format!("{api_base}/api/embeddings");
let url = format!("{api_base}/api/embed");
debug!("Ollama Embeddings Request: {url} {body}");
@ -96,10 +96,8 @@ async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutp
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
let text = data["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok(ChatCompletionsOutput::new(text))
extract_chat_completions(&data)
}
async fn chat_completions_streaming(
@ -142,13 +140,12 @@ async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
let output = vec![res_body.embedding];
Ok(output)
Ok(res_body.embeddings)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
embedding: Vec<f32>,
embeddings: Vec<Vec<f32>>,
}
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
@ -156,22 +153,21 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
messages,
temperature,
top_p,
functions: _,
functions,
stream,
} = data;
let mut is_tool_call = false;
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
match message.content {
MessageContent::Text(text) => json!({
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"content": text,
}),
})],
MessageContent::Array(list) => {
let mut content = vec![];
let mut images = vec![];
@ -195,20 +191,34 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
}
}
let content = content.join("\n\n");
json!({ "role": role, "content": content, "images": images })
}
MessageContent::ToolResults(_) => {
is_tool_call = true;
json!({ "role": role })
vec![json!({ "role": role, "content": content, "images": images })]
}
MessageContent::ToolResults((tool_results, text)) => {
let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| {
json!({
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments,
},
})
}).collect();
let mut messages = vec![
json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls })
];
for tool_result in tool_results {
messages.push(
json!({
"role": "tool",
"content": tool_result.output.to_string(),
})
);
}
messages
},
}
})
.collect();
if is_tool_call {
bail!("The client does not support function calling",);
}
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
@ -232,6 +242,50 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
if let Some(v) = top_p {
body["options"]["top_p"] = v.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"type": "function",
"function": v,
})
})
.collect();
}
Ok(body)
}
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["message"]["content"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["message"]["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
if let (Some(name), arguments) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].clone(),
) {
Some(ToolCall::new(name.to_string(), arguments, None))
} else {
None
}
})
.collect()
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: None,
input_tokens: data["prompt_eval_count"].as_u64(),
output_tokens: data["eval_count"].as_u64(),
};
Ok(output)
}

Loading…
Cancel
Save