feat: support function calling (#514)

* feat: support function calling

* fix on Windows OS

* implement multi-step function calling

* fix on Windows OS

* add an error for clients that do not support function calling

* refactor the message data structure and make the claude client support function calling

* support reuse previous call results

* improve error handling for function calling

* use the prefix `may_` as an indicator for `execute`-type functions

.gitignore

@@ -1,2 +1,3 @@
/target
/tmp
/tmp
*.log

@@ -24,6 +24,34 @@ test-platform-env() {
cargo run -- "$@"
}
# @cmd Test function calling
# @option --model[?`_choice_model`]
# @option --preset[=default|weather|multi-weathers]
# @flag -S --no-stream
# @arg text~
test-function-calling() {
args=(--role %functions%)
if [[ -n "$argc_model" ]]; then
args+=("--model" "$argc_model")
fi
if [[ -n "$argc_no_stream" ]]; then
args+=("-S")
fi
if [[ -z "$argc_text" ]]; then
case "$argc_preset" in
multi-weathers)
text="what is the weather in London and Pairs?"
;;
weather|*)
text="what is the weather in London?"
;;
esac
else
text="${argc_text[*]}"
fi
cargo run -- "${args[@]}" "$text"
}
# @cmd Test clients
# @arg clients+[`_choice_client`]
test-clients() {
@@ -36,7 +64,7 @@ test-clients() {
}
# @cmd Test proxy server
# @option -m --model[`_choice_model`]
# @option -m --model[?`_choice_model`]
# @flag -S --no-stream
# @arg text~
test-server() {
@@ -153,10 +181,7 @@ chat-gemini() {
_wrapper curl -i "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
-i -X POST \
-H 'Content-Type: application/json' \
-d '{
"safetySettings":[{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}],
"contents": '"$(_build_msg_gemini $*)"'
}'
-d "$(_build_body gemini "$@")"
}
# @cmd List gemini models
@@ -177,14 +202,9 @@ chat-claude() {
-X POST \
-H 'content-type: application/json' \
-H 'anthropic-version: 2023-06-01' \
-H 'anthropic-beta: tools-2024-05-16' \
-H "x-api-key: $CLAUDE_API_KEY" \
-d '{
"model": "'$argc_model'",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096,
"stream": '$stream'
}
'
-d "$(_build_body claude "$@")"
}
# @cmd Chat with cohere api
@@ -221,11 +241,7 @@ chat-ollama() {
_wrapper curl -i http://localhost:11434/api/chat \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"model": "'$argc_model'",
"stream": '$stream',
"messages": '"$(_build_msg $*)"'
}'
-d "$(_build_body ollama "$@")"
}
# @cmd Chat with vertexai api
@@ -246,10 +262,7 @@ chat-vertexai() {
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
-d '{
"contents": '"$(_build_msg_gemini $*)"',
"generationConfig": {}
}'
-d "$(_build_body vertexai "$@")"
}
# @cmd Chat with vertexai-claude api
@@ -266,12 +279,7 @@ chat-vertexai-claude() {
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
-d '{
"anthropic_version": "vertex-2023-10-16",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096,
"stream": '$stream'
}'
-d "$(_build_body vertexai-claude "$@")"
}
# @cmd Chat with bedrock api
@@ -285,11 +293,7 @@ chat-bedrock() {
body='{"prompt":"'"$*"'"}'
;;
anthropic.*)
body='{
"anthropic_version": "vertex-2023-10-16",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096
}'
body="$(_build_body bedrock-claude "$@")"
;;
*)
_die "Invalid model: $argc_model"
@@ -314,10 +318,7 @@ chat-cloudflare() {
_wrapper curl -i "$url" \
-X POST \
-H "Authorization: Bearer $CLOUDFLARE_API_KEY" \
-d '{
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}'
-d "$(_build_body cloudflare "$@")"
}
# @cmd Chat with replicate api
@@ -331,12 +332,8 @@ chat-replicate() {
-X POST \
-H "Authorization: Bearer $REPLICATE_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"stream": '$stream',
"input": {
"prompt": "'"$*"'"
}
}')"
-d "$(_build_body replicate "$@")" \
)"
echo "$res"
if [[ -n "$argc_no_stream" ]]; then
prediction_url="$(echo "$res" | jq -r '.urls.get')"
@@ -373,10 +370,7 @@ chat-ernie() {
url="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/$argc_model?access_token=$ACCESS_TOKEN"
_wrapper curl -i "$url" \
-X POST \
-d '{
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}'
-d "$(_build_body ernie "$@")"
}
@@ -397,13 +391,7 @@ chat-qianwen() {
-X POST \
-H "Authorization: Bearer $QIANWEN_API_KEY" \
-H 'Content-Type: application/json' $stream_args \
-d '{
"model": "'$argc_model'",
"parameters": '"$parameters_args"',
"input":{
"messages": '"$(_build_msg $*)"'
}
}'
-d "$(_build_body qianwen "$@")"
}
_argc_before() {
@@ -420,12 +408,7 @@ _openai_chat() {
-X POST \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $api_key" \
--data '{
"model": "'$argc_model'",
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}
'
-d "$(_build_body openai "$@")"
}
_openai_models() {
@@ -460,35 +443,112 @@ _choice_openai_compatible_platform() {
done
}
_build_msg() {
    if [[ $# -eq 0 ]]; then
        cat tmp/messages.json
    else
        echo '
[
    {
        "role": "user",
        "content": "'"$*"'"
    }
]
'
    fi
}
_build_msg_gemini() {
    if [[ $# -eq 0 ]]; then
        cat tmp/messages.gemini.json
    else
        echo '
[{
    "role": "user",
    "parts": [
        {
            "text": "'"$*"'"
        }
    ]
}]
'
    fi
}
_build_body() {
    kind="$1"
    if [[ "$#" -eq 1 ]]; then
        file="${BODY_FILE:-"tmp/body/$1.json"}"
        if [[ -f "$file" ]]; then
            cat "$file" | \
            sed \
                -e 's/"model": ".*"/"model": "'"$argc_model"'"/' \
                -e 's/"stream": \(true\|false\)/"stream": '$stream'/'
        fi
    else
        shift
        case "$kind" in
        openai|ollama)
            echo '{
    "model": "'$argc_model'",
    "messages": [
        {
            "role": "user",
            "content": "'"$*"'"
        }
    ],
    "stream": '$stream'
}'
            ;;
        claude)
            echo '{
    "model": "'$argc_model'",
    "messages": [
        {
            "role": "user",
            "content": "'"$*"'"
        }
    ],
    "max_tokens": 4096,
    "stream": '$stream'
}'
            ;;
        vertexai-claude|bedrock-claude)
            echo '{
    "anthropic_version": "vertex-2023-10-16",
    "messages": [
        {
            "role": "user",
            "content": "'"$*"'"
        }
    ],
    "max_tokens": 4096,
    "stream": '$stream'
}'
            ;;
        gemini|vertexai)
            echo '{
    "contents": [{
        "role": "user",
        "parts": [
            {
                "text": "'"$*"'"
            }
        ]
    }],
    "safetySettings":[{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}]
}'
            ;;
        ernie|cloudflare)
            echo '{
    "messages": [
        {
            "role": "user",
            "content": "'"$*"'"
        }
    ],
    "stream": '$stream'
}'
            ;;
        replicate)
            echo '{
    "stream": '$stream',
    "input": {
        "prompt": "'"$*"'"
    }
}'
            ;;
        qianwen)
            echo '{
    "model": "'$argc_model'",
    "parameters": '"$parameters_args"',
    "input":{
        "messages": [
            {
                "role": "user",
                "content": "'"$*"'"
            }
        ]
    }
}'
            ;;
        *)
            _die "Unsupported build body for $kind"
            ;;
        esac
    fi
}

Cargo.lock

@@ -38,7 +38,6 @@ dependencies = [
"aws-smithy-eventstream",
"base64 0.22.1",
"bincode",
"bitflags 2.5.0",
"bstr",
"bytes",
"chrono",
@@ -59,6 +58,7 @@ dependencies = [
"log",
"mime_guess",
"nu-ansi-term 0.50.0",
"num_cpus",
"parking_lot",
"reedline",
"reqwest",
@@ -72,6 +72,7 @@ dependencies = [
"simplelog",
"syntect",
"textwrap",
"threadpool",
"time",
"tokio",
"tokio-graceful",
@@ -1071,6 +1072,7 @@ checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown",
"serde",
]
[[package]]
@@ -2229,6 +2231,15 @@ dependencies = [
"once_cell",
]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]]
name = "time"
version = "0.3.36"

@@ -44,19 +44,20 @@ log = "0.4.20"
shell-words = "1.1.0"
mime_guess = "2.0.4"
sha2 = "0.10.8"
bitflags = "2.4.1"
unicode-width = "0.1.11"
async-recursion = "1.1.0"
async-recursion = "1.1.1"
http = "1.1.0"
http-body-util = "0.1"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["server-auto", "client-legacy"] }
time = { version = "0.3.36", features = ["macros"] }
indexmap = "2.2.6"
indexmap = { version = "2.2.6", features = ["serde"] }
hmac = "0.12.1"
aws-smithy-eventstream = "0.60.4"
urlencoding = "2.1.3"
unicode-segmentation = "1.11.0"
num_cpus = "1.16.0"
threadpool = "1.8.1"
[dependencies.reqwest]
version = "0.12.0"

@@ -15,6 +15,9 @@ prelude: null # Set a default role or session to start with (
# if unset fallback to $EDITOR and $VISUAL
buffer_editor: null
# Controls the function calling feature. For setup instructions, visit https://github.com/sigoden/llm-functions.
function_calling: false
# Compress session when token count reaches or exceeds this threshold (must be at least 1000)
compress_threshold: 1000
# Text prompt used for creating a concise summary of session message

@@ -15,33 +15,39 @@
max_output_tokens: 4096
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: gpt-3.5-turbo-1106
max_input_tokens: 16385
max_output_tokens: 4096
input_price: 1
output_price: 2
supports_function_calling: true
- name: gpt-4o
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 5
output_price: 15
supports_vision: true
supports_function_calling: true
- name: gpt-4-turbo
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_vision: true
supports_function_calling: true
- name: gpt-4-turbo-preview
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_function_calling: true
- name: gpt-4-1106-preview
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_function_calling: true
- name: gpt-4-vision-preview
max_input_tokens: 128000
max_output_tokens: 4096
@@ -73,6 +79,7 @@
max_output_tokens: 2048
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: gemini-1.0-pro-vision-latest
max_input_tokens: 12288
max_output_tokens: 4096
@@ -85,12 +92,14 @@
input_price: 0.35
output_price: 0.53
supports_vision: true
supports_function_calling: true
- name: gemini-1.5-pro-latest
max_input_tokens: 1048576
max_output_tokens: 8192
input_price: 3.5
output_price: 10.5
supports_vision: true
supports_function_calling: true
- platform: claude
# docs:
@@ -106,6 +115,7 @@
input_price: 15
output_price: 75
supports_vision: true
supports_function_calling: true
- name: claude-3-sonnet-20240229
max_input_tokens: 200000
max_output_tokens: 4096
@@ -113,6 +123,7 @@
input_price: 3
output_price: 15
supports_vision: true
supports_function_calling: true
- name: claude-3-haiku-20240307
max_input_tokens: 200000
max_output_tokens: 4096
@@ -120,6 +131,7 @@
input_price: 0.25
output_price: 1.25
supports_vision: true
supports_function_calling: true
- platform: mistral
# docs:
@@ -149,6 +161,7 @@
max_input_tokens: 32000
input_price: 8
output_price: 24
supports_function_calling: true
- platform: cohere
# docs:
@@ -163,11 +176,13 @@
max_output_tokens: 4000
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: command-r-plus
max_input_tokens: 128000
max_output_tokens: 4000
input_price: 3
output_price: 15
supports_function_calling: true
- platform: perplexity
# docs:
@@ -242,12 +257,13 @@
# notes:
# - get max_output_tokens info from models doc
models:
- name: gemini-1.0-pro
- name: gemini-1.0-pro-002
max_input_tokens: 24568
max_output_tokens: 8192
input_price: 0.125
output_price: 0.375
- name: gemini-1.0-pro-vision
supports_function_calling: true
- name: gemini-1.0-pro-vision-001
max_input_tokens: 14336
max_output_tokens: 2048
input_price: 0.125
@@ -387,6 +403,7 @@
# docs:
# - https://replicate.com/docs
# - https://replicate.com/pricing
# - https://replicate.com/docs/reference/http
# notes:
# - max_output_tokens is required but unknown
models:
@@ -695,20 +712,24 @@
max_input_tokens: 16385
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: openai/gpt-4o
max_input_tokens: 128000
input_price: 5
output_price: 15
supports_vision: true
supports_function_calling: true
- name: openai/gpt-4-turbo
max_input_tokens: 128000
input_price: 10
output_price: 30
supports_vision: true
supports_function_calling: true
- name: openai/gpt-4-turbo-preview
max_input_tokens: 128000
input_price: 10
output_price: 30
supports_function_calling: true
- name: openai/gpt-4-vision-preview
max_input_tokens: 128000
max_output_tokens: 4096

@@ -1,7 +1,5 @@
use super::openai::openai_build_body;
use super::{
AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
};
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@@ -12,7 +10,7 @@ pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@@ -41,7 +39,8 @@ impl AzureOpenAIClient {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
&api_base, self.model.name
&api_base,
self.model.name()
);
debug!("AzureOpenAI Request: {url} {body}");

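A recurring change in every client diff below: ModelConfig is renamed to ModelData, and direct field access self.model.name is replaced by an accessor self.model.name(). A hedged sketch of that encapsulation (the internal layout is an assumption reconstructed from call sites; only the rename and the accessor are confirmed by the diffs):

// Assumed shape, for illustration only.
pub struct ModelData {
    pub name: String,
    // capability flags such as supports_function_calling presumably live here
}

pub struct Model {
    data: ModelData, // assumed private storage
}

impl Model {
    // Call sites now read self.model.name() instead of self.model.name.
    pub fn name(&self) -> &str {
        &self.data.name
    }
}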
@@ -1,8 +1,8 @@
use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptAction, PromptFormat, PromptKind, SendData, SseHandler,
LLAMA3_PROMPT_FORMAT, MISTRAL_PROMPT_FORMAT,
catch_error, generate_prompt, BedrockClient, Client, CompletionOutput, ExtraConfig, Model,
ModelData, PromptAction, PromptFormat, PromptKind, SendData, SseHandler, LLAMA3_PROMPT_FORMAT,
MISTRAL_PROMPT_FORMAT,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
@@ -30,7 +30,7 @@ pub struct BedrockConfig {
pub secret_access_key: Option<String>,
pub region: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@@ -42,8 +42,8 @@ impl Client for BedrockClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
) -> Result<CompletionOutput> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await
}
@@ -54,7 +54,7 @@ impl Client for BedrockClient {
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message_streaming(builder, handler, &model_category).await
}
@@ -91,7 +91,7 @@ impl BedrockClient {
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let model_name = &self.model.name;
let model_name = &self.model.name();
let uri = if data.stream {
format!("/model/{model_name}/invoke-with-response-stream")
} else {
@@ -129,7 +129,7 @@ impl BedrockClient {
async fn send_message(
builder: RequestBuilder,
model_category: &ModelCategory,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@@ -138,6 +138,7 @@ async fn send_message(
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
match model_category {
ModelCategory::Anthropic => claude_extract_completion(&data),
ModelCategory::MetaLlama3 => llama_extract_completion(&data),
@@ -172,7 +173,7 @@ async fn send_message_streaming(
let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
})?;
// debug!("bedrock chunk: {data}");
debug!("stream-data: {data}");
match model_category {
ModelCategory::Anthropic => {
if let Some(typ) = data["type"].as_str() {
@@ -230,6 +231,7 @@ fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Res
messages,
temperature,
top_p,
functions: _,
stream: _,
} = data;
let prompt = generate_prompt(&messages, pt)?;
@@ -253,6 +255,7 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream: _,
} = data;
let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?;
@@ -271,23 +274,25 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["generation"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
id: None,
input_tokens: data["prompt_token_count"].as_u64(),
output_tokens: data["generation_token_count"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
fn mistral_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn mistral_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["outputs"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
Ok(CompletionOutput::new(text))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

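Before the Claude client changes, it is worth restating the shared data types these hunks keep destructuring. A hedged reconstruction from call sites such as functions: _, ToolCall::new(name, arguments, id), and CompletionOutput::new(text); the Message and FunctionDeclaration stand-ins are simplified so the sketch compiles on its own:

use serde_json::Value;

// Simplified stand-ins; the real types live in the PR's message module.
pub struct Message { pub role: String, pub content: String }
pub struct FunctionDeclaration { pub name: String, pub description: String, pub parameters: Value }

// Reconstructed from `let SendData { messages, temperature, top_p, functions, stream } = data;`
pub struct SendData {
    pub messages: Vec<Message>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub functions: Option<Vec<FunctionDeclaration>>, // None when the client gets no tools
    pub stream: bool,
}

// Inferred from ToolCall::new(function_name, arguments, Some(function_id)).
pub struct ToolCall {
    pub name: String,
    pub arguments: Value,
    pub id: Option<String>, // Claude supplies call ids; Cohere streams calls without one
}

impl ToolCall {
    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self { name, arguments, id }
    }
}

// Replaces the old (String, CompletionDetails) tuple returned by send_message.
pub struct CompletionOutput {
    pub text: String,
    pub tool_calls: Vec<ToolCall>,
    pub id: Option<String>,
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
}

impl CompletionOutput {
    pub fn new(text: &str) -> Self {
        Self {
            text: text.to_string(),
            tool_calls: vec![],
            id: None,
            input_tokens: None,
            output_tokens: None,
        }
    }
}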
@@ -1,10 +1,10 @@
use super::{
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptAction, PromptKind,
SendData, SsMmessage, SseHandler,
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, CompletionOutput,
ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction,
PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@@ -16,7 +16,7 @@ pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@@ -36,7 +36,9 @@ impl ClaudeClient {
debug!("Claude Request: {url} {body}");
let mut builder = client.post(url).json(&body);
builder = builder.header("anthropic-version", "2023-06-01");
builder = builder
.header("anthropic-version", "2023-06-01")
.header("anthropic-beta", "tools-2024-05-16");
if let Some(api_key) = api_key {
builder = builder.header("x-api-key", api_key)
}
@@ -51,13 +53,14 @@ impl_client_trait!(
claude_send_message_streaming
);
pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
pub async fn claude_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
claude_extract_completion(&data)
}
@@ -65,13 +68,59 @@ pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
match typ {
"content_block_start" => {
if let (Some("tool_use"), Some(name), Some(id)) = (
data["content_block"]["type"].as_str(),
data["content_block"]["name"].as_str(),
data["content_block"]["id"].as_str(),
) {
if !function_name.is_empty() {
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_name = name.into();
function_arguments.clear();
function_id = id.into();
}
}
"content_block_delta" => {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
} else if let (true, Some(partial_json)) = (
!function_name.is_empty(),
data["delta"]["partial_json"].as_str(),
) {
function_arguments.push_str(partial_json);
}
}
"content_block_stop" => {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
}
_ => {}
}
}
Ok(false)
@@ -85,46 +134,91 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
mut messages,
temperature,
top_p,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"type": "text", "text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"type": "text", "text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"content": text,
})],
MessageContent::Array(list) => {
let content: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => {
json!({"type": "text", "text": text})
}
}
})
.collect(),
};
json!({ "role": role, "content": content })
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect();
vec![json!({
"role": role,
"content": content,
})]
}
MessageContent::ToolResults((tool_call_results, text)) => {
let mut tool_call = vec![];
let mut tool_result = vec![];
if !text.is_empty() {
tool_call.push(json!({
"type": "text",
"text": text,
}))
}
for tool_call_result in tool_call_results {
tool_call.push(json!({
"type": "tool_use",
"id": tool_call_result.call.id,
"name": tool_call_result.call.name,
"input": tool_call_result.call.arguments,
}));
tool_result.push(json!({
"type": "tool_result",
"tool_use_id": tool_call_result.call.id,
"content": tool_call_result.output.to_string(),
}));
}
vec![
json!({
"role": "assistant",
"content": tool_call,
}),
json!({
"role": "user",
"content": tool_result,
}),
]
}
}
})
.collect();
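To make the ToolResults branch concrete: each executed call is replayed to Claude as an assistant tool_use block followed by a user tool_result block. A sketch of the JSON pair the mapping above produces for one call (tool name, id, and values are illustrative):

use serde_json::json;

fn main() {
    // Assistant turn: what the model previously asked for.
    let assistant = json!({
        "role": "assistant",
        "content": [{
            "type": "tool_use",
            "id": "toolu_01",                  // illustrative id
            "name": "get_current_weather",     // illustrative tool name
            "input": { "location": "London" }
        }]
    });
    // User turn: the output our side produced for that call.
    let user = json!({
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": "toolu_01",
            "content": "{\"temperature\":11,\"unit\":\"celsius\"}"
        }]
    });
    println!("{assistant}\n{user}");
}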
@@ -136,7 +230,7 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
}
let mut body = json!({
"model": &model.name,
"model": model.name(),
"messages": messages,
});
if let Some(v) = system_message {
@@ -154,18 +248,61 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"name": v.name,
"description": v.description,
"input_schema": v.parameters,
})
})
.collect();
}
Ok(body)
}
pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["content"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
pub fn claude_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["content"][0]["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["content"].as_array().map(|content| {
content
.iter()
.filter(|content| matches!(content["type"].as_str(), Some("tool_use")))
.collect::<Vec<&Value>>()
}) {
tool_calls = calls
.into_iter()
.filter_map(|call| {
if let (Some(name), Some(input), Some(id)) = (
call["name"].as_str(),
call.get("input"),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
} else {
None
}
})
.collect();
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
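Pulling the two Claude payload shapes together: claude_build_body advertises functions through a tools array keyed by input_schema, and claude_extract_completion reads tool_use blocks back out of content. A hedged sketch of both shapes (the weather tool is illustrative):

use serde_json::json;

fn main() {
    // Request side: the `tools` array built from FunctionDeclaration values.
    let tools = json!([{
        "name": "get_current_weather",         // illustrative
        "description": "Get the current weather for a location",
        "input_schema": {
            "type": "object",
            "properties": { "location": { "type": "string" } },
            "required": ["location"]
        }
    }]);
    // Response side: the `content` array claude_extract_completion parses;
    // text may be empty when the model only calls tools.
    let content = json!([
        { "type": "text", "text": "Checking the weather." },
        { "type": "tool_use", "id": "toolu_01", "name": "get_current_weather",
          "input": { "location": "London" } }
    ]);
    println!("{tools}\n{content}");
}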

@@ -1,5 +1,5 @@
use super::{
catch_error, sse_stream, CloudflareClient, CompletionDetails, ExtraConfig, Model, ModelConfig,
catch_error, sse_stream, CloudflareClient, CompletionOutput, ExtraConfig, Model, ModelData,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
@@ -16,7 +16,7 @@ pub struct CloudflareConfig {
pub account_id: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@@ -37,7 +37,7 @@ impl CloudflareClient {
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self.model.name
self.model.name()
);
debug!("Cloudflare Request: {url} {body}");
@@ -50,7 +50,7 @@ impl CloudflareClient {
impl_client_trait!(CloudflareClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@@ -58,6 +58,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_completion(&data)
}
@@ -67,6 +68,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
return Ok(true);
}
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["response"].as_str() {
handler.text(text)?;
}
@@ -80,11 +82,12 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream,
} = data;
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"messages": messages,
});
@@ -104,10 +107,10 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["result"]["response"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
Ok(CompletionOutput::new(text))
}

@@ -1,9 +1,9 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SseHandler, ToolCall,
};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@@ -15,7 +15,7 @@ pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@@ -42,7 +42,7 @@ impl CohereClient {
impl_client_trait!(CohereClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@@ -50,6 +50,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_completion(&data)
}
@@ -62,10 +63,25 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
} else {
let handle = |data: &str| -> Result<()> {
let data: Value = serde_json::from_str(data)?;
debug!("stream-data: {data}");
if let Some("text-generation") = data["event_type"].as_str() {
if let Some(text) = data["text"].as_str() {
handler.text(text)?;
}
} else if let Some("tool-calls-generation") = data["event_type"].as_str() {
if let Some(tool_calls) = data["tool_calls"].as_array() {
for call in tool_calls {
if let (Some(name), Some(args)) =
(call["name"].as_str(), call["parameters"].as_object())
{
handler.tool_call(ToolCall::new(
name.to_string(),
json!(args),
None,
))?;
}
}
}
}
Ok(())
};
@@ -79,24 +95,28 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
mut messages,
temperature,
top_p,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut image_urls = vec![];
let mut tool_results = None;
let mut messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = match message.role {
.filter_map(|message| {
let Message { role, content } = message;
let role = match role {
MessageRole::User => "USER",
_ => "CHATBOT",
};
match message.content {
MessageContent::Text(text) => json!({
match content {
MessageContent::Text(text) => Some(json!({
"role": role,
"message": text,
}),
})),
MessageContent::Array(list) => {
let list: Vec<String> = list
.into_iter()
@@ -110,7 +130,11 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
}
})
.collect();
json!({ "role": role, "message": list.join("\n\n") })
Some(json!({ "role": role, "message": list.join("\n\n") }))
}
MessageContent::ToolResults((tool_call_results, _)) => {
tool_results = Some(tool_call_results);
None
}
}
})
@@ -123,10 +147,29 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
let message = message["message"].as_str().unwrap_or_default();
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"message": message,
});
if let Some(tool_results) = tool_results {
let tool_results: Vec<_> = tool_results
.into_iter()
.map(|tool_call_result| {
json!({
"call": {
"name": tool_call_result.call.name,
"parameters": tool_call_result.call.arguments,
},
"outputs": [
tool_call_result.output,
]
})
})
.collect();
body["tool_results"] = json!(tool_results);
}
if let Some(v) = system_message {
body["preamble"] = v.into();
}
@@ -148,18 +191,60 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
let required = v.parameters.required.clone().unwrap_or_default();
let mut parameter_definitions = json!({});
if let Some(properties) = &v.parameters.properties {
for (key, value) in properties {
let mut value: Value = json!(value);
if value.is_object() && required.iter().any(|x| x == key) {
value["required"] = true.into();
}
parameter_definitions[key] = value;
}
}
json!({
"name": v.name,
"description": v.description,
"parameter_definitions": parameter_definitions,
})
})
.collect();
}
Ok(body)
}
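Cohere does not accept raw JSON Schema, so the loop above flattens properties plus required into its parameter_definitions form, marking each required property inline. A hedged before/after sketch (the weather parameter is illustrative):

use serde_json::json;

fn main() {
    // Input: JSON Schema fragment stored on the function declaration.
    let schema = json!({
        "type": "object",
        "properties": {
            "location": { "type": "string", "description": "City name" }
        },
        "required": ["location"]
    });
    // Output: equivalent Cohere parameter_definitions; required properties
    // gain an inline "required": true flag.
    let parameter_definitions = json!({
        "location": { "type": "string", "description": "City name", "required": true }
    });
    println!("{schema}\n{parameter_definitions}");
}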
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
if let (Some(name), Some(parameters)) =
(call["name"].as_str(), call["parameters"].as_object())
{