feat: agent can reuse tools (#690)

pull/691/head
sigoden 3 months ago committed by GitHub
parent 8b0c648a73
commit 10a4c23c83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -29,8 +29,7 @@ summary_prompt: 'This is a summary of the chat history as a recap: '
# Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
# web_search: 'search_duckduckgo'
# code_interpreter: 'execute_py_code'
# fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default
# Regex for seletecting dangerous functions
# User confirmation is required when executing these functions

@ -148,12 +148,7 @@ impl RoleLike for Agent {
}
fn use_tools(&self) -> Option<String> {
let common_tools = &self.definition.common_tools;
if common_tools.is_empty() {
None
} else {
Some(common_tools.join(","))
}
self.config.use_tools.clone()
}
fn set_model(&mut self, model: &Model) {
@ -169,7 +164,9 @@ impl RoleLike for Agent {
self.config.top_p = value;
}
fn set_use_tools(&mut self, _value: Option<String>) {}
fn set_use_tools(&mut self, value: Option<String>) {
self.config.use_tools = value;
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
@ -182,6 +179,8 @@ pub struct AgentConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dangerously_functions_filter: Option<String>,
}
@ -206,8 +205,6 @@ pub struct AgentDefinition {
pub conversation_starters: Vec<String>,
#[serde(default)]
pub documents: Vec<String>,
#[serde(default)]
pub common_tools: Vec<String>,
}
impl AgentDefinition {

@ -555,9 +555,6 @@ impl Config {
self.function_calling = value;
}
"use_tools" => {
if self.agent.is_some() {
bail!("This action cannot be performed within an agent.")
}
let value = parse_value(value)?;
self.set_use_tools(value);
}
@ -1059,15 +1056,14 @@ impl Config {
pub fn select_functions(&self, model: &Model, role: &Role) -> Option<Vec<FunctionDeclaration>> {
let mut functions = vec![];
if self.function_calling {
let use_tools = role.use_tools();
if let Some(use_tools) = role.use_tools() {
let mut tool_names: HashSet<String> = Default::default();
let declaration_names: HashSet<String> = self
.functions
.declarations()
.iter()
.map(|v| v.name.to_string())
.collect();
if let Some(use_tools) = use_tools {
let mut tool_names: HashSet<String> = Default::default();
for item in use_tools.split(',') {
let item = item.trim();
if item == "all" {
@ -1096,17 +1092,33 @@ impl Config {
}
})
.collect();
}
if let Some(agent) = &self.agent {
let agent_functions = agent.functions().declarations().to_vec();
functions = [agent_functions, functions].concat();
let mut agent_functions = agent.functions().declarations().to_vec();
let tool_names: HashSet<String> = agent_functions
.iter()
.filter_map(|v| {
if v.agent {
None
} else {
Some(v.name.to_string())
}
})
.collect();
agent_functions.extend(
functions
.into_iter()
.filter(|v| !tool_names.contains(&v.name)),
);
functions = agent_functions;
}
if !model.supports_function_calling() {
if !functions.is_empty() && !model.supports_function_calling() {
functions.clear();
if *IS_STDOUT_TERMINAL {
eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it."));
}
}
}
};
if functions.is_empty() {
None

@ -71,6 +71,10 @@ impl Functions {
Ok(Self { declarations })
}
pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> {
self.declarations.iter().find(|v| v.name == name)
}
pub fn contains(&self, name: &str) -> bool {
self.declarations.iter().any(|v| v.name == name)
}
@ -89,6 +93,8 @@ pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: JsonSchema,
#[serde(skip_serializing, default)]
pub agent: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -154,7 +160,7 @@ impl ToolCall {
let is_dangerously = config.read().is_dangerously_function(&function_name);
let (call_name, cmd_name, mut cmd_args) = match &config.read().agent {
Some(agent) => {
if agent.functions().contains(&function_name) {
if let Some(true) = agent.functions().find(&function_name).map(|v| v.agent) {
(
format!("{}:{}", agent.name(), function_name),
agent.name().to_string(),

Loading…
Cancel
Save