diff --git a/Argcfile.sh b/Argcfile.sh index 5706aa5..315a26c 100755 --- a/Argcfile.sh +++ b/Argcfile.sh @@ -247,7 +247,7 @@ models-cohere() { } # @cmd Chat with ollama api -# @option -m --model=codegemma $OLLAMA_MODEL +# @option -m --model=llama3.1:latest $OLLAMA_MODEL # @flag -S --no-stream # @arg text~ chat-ollama() { diff --git a/src/client/ollama.rs b/src/client/ollama.rs index be065c1..5a3d201 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -1,6 +1,6 @@ use super::*; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{bail, Context, Result}; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -65,10 +65,10 @@ impl OllamaClient { let body = json!({ "model": self.model.name(), - "prompt": data.texts[0], + "input": data.texts, }); - let url = format!("{api_base}/api/embeddings"); + let url = format!("{api_base}/api/embed"); debug!("Ollama Embeddings Request: {url} {body}"); @@ -96,10 +96,8 @@ async fn chat_completions(builder: RequestBuilder) -> Result Result { } let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid embeddings data")?; - let output = vec![res_body.embedding]; - Ok(output) + Ok(res_body.embeddings) } #[derive(Deserialize)] struct EmbeddingsResBody { - embedding: Vec, + embeddings: Vec>, } fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { @@ -156,22 +153,21 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu messages, temperature, top_p, - functions: _, + functions, stream, } = data; - let mut is_tool_call = false; let mut network_image_urls = vec![]; let messages: Vec = messages .into_iter() - .map(|message| { - let role = message.role; - match message.content { - MessageContent::Text(text) => json!({ + .flat_map(|message| { + let Message { role, content } = message; + match content { + MessageContent::Text(text) => vec![json!({ "role": role, "content": text, - }), + })], MessageContent::Array(list) => { let mut content = vec![]; let mut images = vec![]; @@ -195,20 +191,34 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu } } let content = content.join("\n\n"); - json!({ "role": role, "content": content, "images": images }) - } - MessageContent::ToolResults(_) => { - is_tool_call = true; - json!({ "role": role }) + vec![json!({ "role": role, "content": content, "images": images })] } + MessageContent::ToolResults((tool_results, text)) => { + let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| { + json!({ + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments, + }, + }) + }).collect(); + let mut messages = vec![ + json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls }) + ]; + for tool_result in tool_results { + messages.push( + json!({ + "role": "tool", + "content": tool_result.output.to_string(), + }) + ); + } + messages + }, } }) .collect(); - if is_tool_call { - bail!("The client does not support function calling",); - } - if !network_image_urls.is_empty() { bail!( "The model does not support network images: {:?}", @@ -232,6 +242,50 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu if let Some(v) = top_p { body["options"]["top_p"] = v.into(); } + if let Some(functions) = functions { + body["tools"] = functions + .iter() + .map(|v| { + json!({ + "type": "function", + "function": v, + }) + }) + .collect(); + } Ok(body) } + +fn extract_chat_completions(data: &Value) -> Result { + let text = data["message"]["content"].as_str().unwrap_or_default(); + + let mut tool_calls = vec![]; + if let Some(calls) = data["message"]["tool_calls"].as_array() { + tool_calls = calls + .iter() + .filter_map(|call| { + if let (Some(name), arguments) = ( + call["function"]["name"].as_str(), + call["function"]["arguments"].clone(), + ) { + Some(ToolCall::new(name.to_string(), arguments, None)) + } else { + None + } + }) + .collect() + }; + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + let output = ChatCompletionsOutput { + text: text.to_string(), + tool_calls, + id: None, + input_tokens: data["prompt_eval_count"].as_u64(), + output_tokens: data["eval_count"].as_u64(), + }; + Ok(output) +}