diff --git a/config.example.yaml b/config.example.yaml index 45efb65..c55eb8a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -28,6 +28,10 @@ summary_prompt: 'This is a summary of the chat history as a recap: ' # ---- function-calling & agent ---- # Visit https://github.com/sigoden/llm-functions for setup instructions function_calling: true # Enables or disables function calling (Globally). +mapping_tools: # Alias for a tool or toolset + # web_search: 'search_duckduckgo' + # code_interpreter: 'execute_py_code' +use_tools: null # Which tools to use by default # Regex for seletecting dangerous functions # User confirmation is required when executing these functions # e.g. 'execute_command|execute_js_code' 'execute_.*' diff --git a/src/config/agent.rs b/src/config/agent.rs index c7ff9ad..a6257ca 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -1,9 +1,6 @@ use super::*; -use crate::{ - client::Model, - function::{Functions, FunctionsFilter, SELECTED_ALL_FUNCTIONS}, -}; +use crate::{client::Model, function::Functions}; use anyhow::{Context, Result}; use std::{fs::read_to_string, path::Path}; @@ -150,11 +147,12 @@ impl RoleLike for Agent { self.config.top_p } - fn functions_filter(&self) -> Option { - if self.functions.is_empty() { + fn use_tools(&self) -> Option { + let common_tools = &self.definition.common_tools; + if common_tools.is_empty() { None } else { - Some(SELECTED_ALL_FUNCTIONS.into()) + Some(common_tools.join(",")) } } @@ -171,7 +169,7 @@ impl RoleLike for Agent { self.config.top_p = value; } - fn set_functions_filter(&mut self, _value: Option) {} + fn set_use_tools(&mut self, _value: Option) {} } #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -184,7 +182,7 @@ pub struct AgentConfig { #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub dangerously_functions_filter: Option, + pub dangerously_functions_filter: Option, } impl AgentConfig { @@ -208,6 +206,8 @@ pub struct AgentDefinition { pub conversation_starters: Vec, #[serde(default)] pub documents: Vec, + #[serde(default)] + pub common_tools: Vec, } impl AgentDefinition { diff --git a/src/config/mod.rs b/src/config/mod.rs index 80abbb8..d2804ed 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -12,13 +12,14 @@ use crate::client::{ create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig, Model, OPENAI_COMPATIBLE_PLATFORMS, }; -use crate::function::{FunctionDeclaration, Functions, FunctionsFilter, ToolResult}; +use crate::function::{FunctionDeclaration, Functions, ToolResult}; use crate::rag::Rag; use crate::render::{MarkdownRender, RenderOptions}; use crate::utils::*; use anyhow::{anyhow, bail, Context, Result}; use fancy_regex::Regex; +use indexmap::IndexMap; use inquire::{Confirm, Select}; use parking_lot::RwLock; use serde::Deserialize; @@ -103,7 +104,9 @@ pub struct Config { pub summary_prompt: Option, pub function_calling: bool, - pub dangerously_functions_filter: Option, + pub mapping_tools: IndexMap, + pub use_tools: Option, + pub dangerously_functions_filter: Option, pub agents: Vec, pub rag_embedding_model: Option, @@ -164,6 +167,8 @@ impl Default for Config { agent_prelude: None, function_calling: true, + mapping_tools: Default::default(), + use_tools: None, dangerously_functions_filter: None, agents: vec![], @@ -421,6 +426,9 @@ impl Config { if role.top_p().is_none() && self.top_p.is_some() { role.set_top_p(self.top_p); } + if role.use_tools().is_none() && self.use_tools.is_some() && self.agent.is_none() { + role.set_use_tools(self.use_tools.clone()) + } role } @@ -475,6 +483,7 @@ impl Config { ("save_session", format_option_value(&self.save_session)), ("compress_threshold", self.compress_threshold.to_string()), ("function_calling", self.function_calling.to_string()), + ("use_tools", format_option_value(&role.use_tools())), ( "rag_reranker_model", format_option_value(&self.rag_reranker_model), @@ -545,6 +554,13 @@ impl Config { } self.function_calling = value; } + "use_tools" => { + if self.agent.is_some() { + bail!("This action cannot be performed within an agent.") + } + let value = parse_value(value)?; + self.set_use_tools(value); + } "compress_threshold" => { let value = parse_value(value)?; self.set_compress_threshold(value); @@ -584,6 +600,13 @@ impl Config { } } + pub fn set_use_tools(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_use_tools(value), + None => self.use_tools = value, + } + } + pub fn set_save_session(&mut self, value: Option) { if let Some(session) = self.session.as_mut() { session.set_save_session(value); @@ -1034,23 +1057,62 @@ impl Config { } pub fn select_functions(&self, model: &Model, role: &Role) -> Option> { - let mut functions = None; + let mut functions = vec![]; if self.function_calling { - let filter = role.functions_filter(); - if let Some(filter) = filter { - functions = match &self.agent { - Some(agent) => agent.functions().select(&filter), - None => self.functions.select(&filter), - }; + let use_tools = role.use_tools(); + let declaration_names: HashSet = self + .functions + .declarations() + .iter() + .map(|v| v.name.to_string()) + .collect(); + if let Some(use_tools) = use_tools { + let mut tool_names: HashSet = Default::default(); + for item in use_tools.split(',') { + let item = item.trim(); + if item == "all" { + tool_names.extend(declaration_names); + break; + } else if let Some(values) = self.mapping_tools.get(item) { + tool_names.extend( + values + .split(',') + .map(|v| v.to_string()) + .filter(|v| declaration_names.contains(v)), + ) + } else if declaration_names.contains(item) { + tool_names.insert(item.to_string()); + } + } + functions = self + .functions + .declarations() + .iter() + .filter_map(|v| { + if tool_names.contains(&v.name) { + Some(v.clone()) + } else { + None + } + }) + .collect(); + if let Some(agent) = &self.agent { + let agent_functions = agent.functions().declarations().to_vec(); + functions = [agent_functions, functions].concat(); + } if !model.supports_function_calling() { - functions = None; + functions.clear(); if *IS_STDOUT_TERMINAL { eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it.")); } } } }; - functions + if functions.is_empty() { + None + } else { + Some(functions) + } } pub fn is_dangerously_function(&self, name: &str) -> bool { @@ -1110,14 +1172,15 @@ impl Config { "max_output_tokens", "temperature", "top_p", - "rag_reranker_model", - "rag_top_k", - "function_calling", - "compress_threshold", + "dry_run", "save", "save_session", + "compress_threshold", + "function_calling", + "use_tools", + "rag_reranker_model", + "rag_top_k", "highlight", - "dry_run", ] .into_iter() .map(|v| (format!("{v} "), None)) @@ -1131,8 +1194,7 @@ impl Config { Some(v) => vec![v.to_string()], None => vec![], }, - "rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(), - "function_calling" => complete_bool(self.function_calling), + "dry_run" => complete_bool(self.dry_run), "save" => complete_bool(self.save), "save_session" => { let save_session = if let Some(session) = &self.session { @@ -1142,8 +1204,26 @@ impl Config { }; complete_option_bool(save_session) } + "function_calling" => complete_bool(self.function_calling), + "use_tools" => { + let mut prefix = String::new(); + if let Some((v, _)) = args[1].rsplit_once(',') { + prefix = format!("{v},"); + } + let mut values = vec![]; + if prefix.is_empty() { + values.push("all".to_string()); + } + values.extend(self.mapping_tools.keys().map(|v| v.to_string())); + values.extend(self.functions.declarations().iter().map(|v| v.name.clone())); + values + .into_iter() + .filter(|v| !prefix.contains(&format!("{v},"))) + .map(|v| format!("{prefix}{v}")) + .collect() + } + "rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(), "highlight" => complete_bool(self.highlight), - "dry_run" => complete_bool(self.dry_run), _ => vec![], }; (values.into_iter().map(|v| (v, None)).collect(), args[1]) diff --git a/src/config/role.rs b/src/config/role.rs index 29d3f3b..3fc7696 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -2,7 +2,6 @@ use super::*; use crate::{ client::{Message, MessageContent, MessageRole, Model}, - function::{FunctionsFilter, SELECTED_ALL_FUNCTIONS}, utils::{detect_os, detect_shell}, }; @@ -21,11 +20,11 @@ pub trait RoleLike { fn model_mut(&mut self) -> &mut Model; fn temperature(&self) -> Option; fn top_p(&self) -> Option; - fn functions_filter(&self) -> Option; + fn use_tools(&self) -> Option; fn set_model(&mut self, model: &Model); fn set_temperature(&mut self, value: Option); fn set_top_p(&mut self, value: Option); - fn set_functions_filter(&mut self, value: Option); + fn set_use_tools(&mut self, value: Option); } #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -43,7 +42,7 @@ pub struct Role { #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] - functions_filter: Option, + use_tools: Option, #[serde(skip)] model: Model, @@ -85,17 +84,13 @@ async function timeout(ms) { .into(), None, ), - ( - "%functions%", - String::new(), - Some(SELECTED_ALL_FUNCTIONS.into()), - ), + ("%functions%", String::new(), Some("all".into())), ] .into_iter() - .map(|(name, prompt, functions_filter)| Self { + .map(|(name, prompt, use_tools)| Self { name: name.into(), prompt, - functions_filter, + use_tools, ..Default::default() }) .collect() @@ -111,8 +106,8 @@ async function timeout(ms) { let model = role_like.model(); let temperature = role_like.temperature(); let top_p = role_like.top_p(); - let functions_filter = role_like.functions_filter(); - self.batch_set(model, temperature, top_p, functions_filter); + let use_tools = role_like.use_tools(); + self.batch_set(model, temperature, top_p, use_tools); } pub fn batch_set( @@ -120,7 +115,7 @@ async function timeout(ms) { model: &Model, temperature: Option, top_p: Option, - functions_filter: Option, + use_tools: Option, ) { self.set_model(model); if temperature.is_some() { @@ -129,8 +124,8 @@ async function timeout(ms) { if top_p.is_some() { self.set_top_p(top_p); } - if functions_filter.is_some() { - self.set_functions_filter(functions_filter); + if use_tools.is_some() { + self.set_use_tools(use_tools); } } @@ -242,8 +237,8 @@ impl RoleLike for Role { self.top_p } - fn functions_filter(&self) -> Option { - self.functions_filter.clone() + fn use_tools(&self) -> Option { + self.use_tools.clone() } fn set_model(&mut self, model: &Model) { @@ -258,8 +253,8 @@ impl RoleLike for Role { self.top_p = value; } - fn set_functions_filter(&mut self, value: Option) { - self.functions_filter = value; + fn set_use_tools(&mut self, value: Option) { + self.use_tools = value; } } diff --git a/src/config/session.rs b/src/config/session.rs index 9aa4ca4..80629c1 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -21,7 +21,7 @@ pub struct Session { #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] - functions_filter: Option, + use_tools: Option, #[serde(skip_serializing_if = "Option::is_none")] save_session: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -134,8 +134,8 @@ impl Session { if let Some(top_p) = self.top_p() { data["top_p"] = top_p.into(); } - if let Some(functions_filter) = self.functions_filter() { - data["functions_filter"] = functions_filter.into(); + if let Some(use_tools) = self.use_tools() { + data["use_tools"] = use_tools.into(); } if let Some(save_session) = self.save_session() { data["save_session"] = save_session.into(); @@ -171,8 +171,8 @@ impl Session { items.push(("top_p", top_p.to_string())); } - if let Some(functions_filter) = self.functions_filter() { - items.push(("functions_filter", functions_filter)); + if let Some(use_tools) = self.use_tools() { + items.push(("use_tools", use_tools)); } if let Some(save_session) = self.save_session() { @@ -242,7 +242,7 @@ impl Session { self.model_id = role.model().id(); self.temperature = role.temperature(); self.top_p = role.top_p(); - self.functions_filter = role.functions_filter(); + self.use_tools = role.use_tools(); self.model = role.model().clone(); self.role_name = role.name().to_string(); self.role_prompt = role.prompt().to_string(); @@ -345,7 +345,7 @@ impl Session { pub fn guard_empty(&self) -> Result<()> { if !self.is_empty() { - bail!("Cannot perform this action in a session with messages") + bail!("This action cannot be performed in a session with messages.") } Ok(()) } @@ -442,8 +442,8 @@ impl RoleLike for Session { self.top_p } - fn functions_filter(&self) -> Option { - self.functions_filter.clone() + fn use_tools(&self) -> Option { + self.use_tools.clone() } fn set_model(&mut self, model: &Model) { @@ -468,9 +468,9 @@ impl RoleLike for Session { } } - fn set_functions_filter(&mut self, value: Option) { - if self.functions_filter != value { - self.functions_filter = value; + fn set_use_tools(&mut self, value: Option) { + if self.use_tools != value { + self.use_tools = value; self.dirty = true; } } diff --git a/src/function.rs b/src/function.rs index fcb4c93..739ef0c 100644 --- a/src/function.rs +++ b/src/function.rs @@ -4,8 +4,7 @@ use crate::{ }; use anyhow::{anyhow, bail, Context, Result}; -use fancy_regex::Regex; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexMap; use inquire::{validator::Validation, Text}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -15,9 +14,7 @@ use std::{ path::Path, }; -pub const SELECTED_ALL_FUNCTIONS: &str = ".*"; pub type ToolResults = (Vec, String); -pub type FunctionsFilter = String; pub fn eval_tool_calls(config: &GlobalConfig, mut calls: Vec) -> Result> { let mut output = vec![]; @@ -53,7 +50,6 @@ impl ToolResult { #[derive(Debug, Clone, Default)] pub struct Functions { - names: IndexSet, declarations: Vec, } @@ -72,35 +68,19 @@ impl Functions { vec![] }; - let names = declarations.iter().map(|v| v.name.clone()).collect(); - - Ok(Self { - names, - declarations, - }) + Ok(Self { declarations }) } - pub fn select(&self, filter: &FunctionsFilter) -> Option> { - let regex = Regex::new(&format!("^({filter})$")).ok()?; - let output: Vec = self - .declarations - .iter() - .filter(|v| regex.is_match(&v.name).unwrap_or_default()) - .cloned() - .collect(); - if output.is_empty() { - None - } else { - Some(output) - } + pub fn contains(&self, name: &str) -> bool { + self.declarations.iter().any(|v| v.name == name) } - pub fn contains(&self, name: &str) -> bool { - self.names.contains(name) + pub fn declarations(&self) -> &[FunctionDeclaration] { + &self.declarations } pub fn is_empty(&self) -> bool { - self.names.is_empty() + self.declarations.is_empty() } } @@ -174,18 +154,21 @@ impl ToolCall { let is_dangerously = config.read().is_dangerously_function(&function_name); let (call_name, cmd_name, mut cmd_args) = match &config.read().agent { Some(agent) => { - if !agent.functions().contains(&function_name) { + if agent.functions().contains(&function_name) { + ( + format!("{}:{}", agent.name(), function_name), + agent.name().to_string(), + vec![function_name], + ) + } else if config.read().functions.contains(&function_name) { + (function_name.clone(), function_name, vec![]) + } else { bail!( "Unexpected call: {} {function_name} {}", agent.name(), self.arguments ); } - ( - format!("{}:{}", agent.name(), function_name), - agent.name().to_string(), - vec![function_name], - ) } None => { if !config.read().functions.contains(&function_name) {