feat: agent can reuse tools (#690)

pull/691/head
sigoden 3 months ago committed by GitHub
parent 8b0c648a73
commit 10a4c23c83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -29,8 +29,7 @@ summary_prompt: 'This is a summary of the chat history as a recap: '
# Visit https://github.com/sigoden/llm-functions for setup instructions # Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally). function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset mapping_tools: # Alias for a tool or toolset
# web_search: 'search_duckduckgo' # fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
# code_interpreter: 'execute_py_code'
use_tools: null # Which tools to use by default use_tools: null # Which tools to use by default
# Regex for seletecting dangerous functions # Regex for seletecting dangerous functions
# User confirmation is required when executing these functions # User confirmation is required when executing these functions

@ -148,12 +148,7 @@ impl RoleLike for Agent {
} }
fn use_tools(&self) -> Option<String> { fn use_tools(&self) -> Option<String> {
let common_tools = &self.definition.common_tools; self.config.use_tools.clone()
if common_tools.is_empty() {
None
} else {
Some(common_tools.join(","))
}
} }
fn set_model(&mut self, model: &Model) { fn set_model(&mut self, model: &Model) {
@ -169,7 +164,9 @@ impl RoleLike for Agent {
self.config.top_p = value; self.config.top_p = value;
} }
fn set_use_tools(&mut self, _value: Option<String>) {} fn set_use_tools(&mut self, value: Option<String>) {
self.config.use_tools = value;
}
} }
#[derive(Debug, Clone, Default, Deserialize, Serialize)] #[derive(Debug, Clone, Default, Deserialize, Serialize)]
@ -182,6 +179,8 @@ pub struct AgentConfig {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>, pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dangerously_functions_filter: Option<String>, pub dangerously_functions_filter: Option<String>,
} }
@ -206,8 +205,6 @@ pub struct AgentDefinition {
pub conversation_starters: Vec<String>, pub conversation_starters: Vec<String>,
#[serde(default)] #[serde(default)]
pub documents: Vec<String>, pub documents: Vec<String>,
#[serde(default)]
pub common_tools: Vec<String>,
} }
impl AgentDefinition { impl AgentDefinition {

@ -555,9 +555,6 @@ impl Config {
self.function_calling = value; self.function_calling = value;
} }
"use_tools" => { "use_tools" => {
if self.agent.is_some() {
bail!("This action cannot be performed within an agent.")
}
let value = parse_value(value)?; let value = parse_value(value)?;
self.set_use_tools(value); self.set_use_tools(value);
} }
@ -1059,15 +1056,14 @@ impl Config {
pub fn select_functions(&self, model: &Model, role: &Role) -> Option<Vec<FunctionDeclaration>> { pub fn select_functions(&self, model: &Model, role: &Role) -> Option<Vec<FunctionDeclaration>> {
let mut functions = vec![]; let mut functions = vec![];
if self.function_calling { if self.function_calling {
let use_tools = role.use_tools(); if let Some(use_tools) = role.use_tools() {
let declaration_names: HashSet<String> = self
.functions
.declarations()
.iter()
.map(|v| v.name.to_string())
.collect();
if let Some(use_tools) = use_tools {
let mut tool_names: HashSet<String> = Default::default(); let mut tool_names: HashSet<String> = Default::default();
let declaration_names: HashSet<String> = self
.functions
.declarations()
.iter()
.map(|v| v.name.to_string())
.collect();
for item in use_tools.split(',') { for item in use_tools.split(',') {
let item = item.trim(); let item = item.trim();
if item == "all" { if item == "all" {
@ -1096,15 +1092,31 @@ impl Config {
} }
}) })
.collect(); .collect();
if let Some(agent) = &self.agent { }
let agent_functions = agent.functions().declarations().to_vec();
functions = [agent_functions, functions].concat(); if let Some(agent) = &self.agent {
} let mut agent_functions = agent.functions().declarations().to_vec();
if !model.supports_function_calling() { let tool_names: HashSet<String> = agent_functions
functions.clear(); .iter()
if *IS_STDOUT_TERMINAL { .filter_map(|v| {
eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it.")); if v.agent {
} None
} else {
Some(v.name.to_string())
}
})
.collect();
agent_functions.extend(
functions
.into_iter()
.filter(|v| !tool_names.contains(&v.name)),
);
functions = agent_functions;
}
if !functions.is_empty() && !model.supports_function_calling() {
functions.clear();
if *IS_STDOUT_TERMINAL {
eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it."));
} }
} }
}; };

@ -71,6 +71,10 @@ impl Functions {
Ok(Self { declarations }) Ok(Self { declarations })
} }
pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> {
self.declarations.iter().find(|v| v.name == name)
}
pub fn contains(&self, name: &str) -> bool { pub fn contains(&self, name: &str) -> bool {
self.declarations.iter().any(|v| v.name == name) self.declarations.iter().any(|v| v.name == name)
} }
@ -89,6 +93,8 @@ pub struct FunctionDeclaration {
pub name: String, pub name: String,
pub description: String, pub description: String,
pub parameters: JsonSchema, pub parameters: JsonSchema,
#[serde(skip_serializing, default)]
pub agent: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -154,7 +160,7 @@ impl ToolCall {
let is_dangerously = config.read().is_dangerously_function(&function_name); let is_dangerously = config.read().is_dangerously_function(&function_name);
let (call_name, cmd_name, mut cmd_args) = match &config.read().agent { let (call_name, cmd_name, mut cmd_args) = match &config.read().agent {
Some(agent) => { Some(agent) => {
if agent.functions().contains(&function_name) { if let Some(true) = agent.functions().find(&function_name).map(|v| v.agent) {
( (
format!("{}:{}", agent.name(), function_name), format!("{}:{}", agent.name(), function_name),
agent.name().to_string(), agent.name().to_string(),

Loading…
Cancel
Save