diff --git a/config.example.yaml b/config.example.yaml index c55eb8a..01ee994 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -29,8 +29,7 @@ summary_prompt: 'This is a summary of the chat history as a recap: ' # Visit https://github.com/sigoden/llm-functions for setup instructions function_calling: true # Enables or disables function calling (Globally). mapping_tools: # Alias for a tool or toolset - # web_search: 'search_duckduckgo' - # code_interpreter: 'execute_py_code' + # fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write' use_tools: null # Which tools to use by default # Regex for seletecting dangerous functions # User confirmation is required when executing these functions diff --git a/src/config/agent.rs b/src/config/agent.rs index a6257ca..82f887c 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -148,12 +148,7 @@ impl RoleLike for Agent { } fn use_tools(&self) -> Option { - let common_tools = &self.definition.common_tools; - if common_tools.is_empty() { - None - } else { - Some(common_tools.join(",")) - } + self.config.use_tools.clone() } fn set_model(&mut self, model: &Model) { @@ -169,7 +164,9 @@ impl RoleLike for Agent { self.config.top_p = value; } - fn set_use_tools(&mut self, _value: Option) {} + fn set_use_tools(&mut self, value: Option) { + self.config.use_tools = value; + } } #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -182,6 +179,8 @@ pub struct AgentConfig { #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] + use_tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub dangerously_functions_filter: Option, } @@ -206,8 +205,6 @@ pub struct AgentDefinition { pub conversation_starters: Vec, #[serde(default)] pub documents: Vec, - #[serde(default)] - pub common_tools: Vec, } impl AgentDefinition { diff --git a/src/config/mod.rs b/src/config/mod.rs index d2804ed..eeffc87 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -555,9 +555,6 @@ impl Config { self.function_calling = value; } "use_tools" => { - if self.agent.is_some() { - bail!("This action cannot be performed within an agent.") - } let value = parse_value(value)?; self.set_use_tools(value); } @@ -1059,15 +1056,14 @@ impl Config { pub fn select_functions(&self, model: &Model, role: &Role) -> Option> { let mut functions = vec![]; if self.function_calling { - let use_tools = role.use_tools(); - let declaration_names: HashSet = self - .functions - .declarations() - .iter() - .map(|v| v.name.to_string()) - .collect(); - if let Some(use_tools) = use_tools { + if let Some(use_tools) = role.use_tools() { let mut tool_names: HashSet = Default::default(); + let declaration_names: HashSet = self + .functions + .declarations() + .iter() + .map(|v| v.name.to_string()) + .collect(); for item in use_tools.split(',') { let item = item.trim(); if item == "all" { @@ -1096,15 +1092,31 @@ impl Config { } }) .collect(); - if let Some(agent) = &self.agent { - let agent_functions = agent.functions().declarations().to_vec(); - functions = [agent_functions, functions].concat(); - } - if !model.supports_function_calling() { - functions.clear(); - if *IS_STDOUT_TERMINAL { - eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it.")); - } + } + + if let Some(agent) = &self.agent { + let mut agent_functions = agent.functions().declarations().to_vec(); + let tool_names: HashSet = agent_functions + .iter() + .filter_map(|v| { + if v.agent { + None + } else { + Some(v.name.to_string()) + } + }) + .collect(); + agent_functions.extend( + functions + .into_iter() + .filter(|v| !tool_names.contains(&v.name)), + ); + functions = agent_functions; + } + if !functions.is_empty() && !model.supports_function_calling() { + functions.clear(); + if *IS_STDOUT_TERMINAL { + eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it.")); } } }; diff --git a/src/function.rs b/src/function.rs index 739ef0c..0016405 100644 --- a/src/function.rs +++ b/src/function.rs @@ -71,6 +71,10 @@ impl Functions { Ok(Self { declarations }) } + pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> { + self.declarations.iter().find(|v| v.name == name) + } + pub fn contains(&self, name: &str) -> bool { self.declarations.iter().any(|v| v.name == name) } @@ -89,6 +93,8 @@ pub struct FunctionDeclaration { pub name: String, pub description: String, pub parameters: JsonSchema, + #[serde(skip_serializing, default)] + pub agent: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -154,7 +160,7 @@ impl ToolCall { let is_dangerously = config.read().is_dangerously_function(&function_name); let (call_name, cmd_name, mut cmd_args) = match &config.read().agent { Some(agent) => { - if agent.functions().contains(&function_name) { + if let Some(true) = agent.functions().find(&function_name).map(|v| v.agent) { ( format!("{}:{}", agent.name(), function_name), agent.name().to_string(),