feat: enhenced flexibility for use tools (#688)

pull/690/head
sigoden 3 months ago committed by GitHub
parent 6a56af270e
commit 8b0c648a73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -28,6 +28,10 @@ summary_prompt: 'This is a summary of the chat history as a recap: '
# ---- function-calling & agent ---- # ---- function-calling & agent ----
# Visit https://github.com/sigoden/llm-functions for setup instructions # Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally). function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
# web_search: 'search_duckduckgo'
# code_interpreter: 'execute_py_code'
use_tools: null # Which tools to use by default
# Regex for seletecting dangerous functions # Regex for seletecting dangerous functions
# User confirmation is required when executing these functions # User confirmation is required when executing these functions
# e.g. 'execute_command|execute_js_code' 'execute_.*' # e.g. 'execute_command|execute_js_code' 'execute_.*'

@ -1,9 +1,6 @@
use super::*; use super::*;
use crate::{ use crate::{client::Model, function::Functions};
client::Model,
function::{Functions, FunctionsFilter, SELECTED_ALL_FUNCTIONS},
};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::{fs::read_to_string, path::Path}; use std::{fs::read_to_string, path::Path};
@ -150,11 +147,12 @@ impl RoleLike for Agent {
self.config.top_p self.config.top_p
} }
fn functions_filter(&self) -> Option<FunctionsFilter> { fn use_tools(&self) -> Option<String> {
if self.functions.is_empty() { let common_tools = &self.definition.common_tools;
if common_tools.is_empty() {
None None
} else { } else {
Some(SELECTED_ALL_FUNCTIONS.into()) Some(common_tools.join(","))
} }
} }
@ -171,7 +169,7 @@ impl RoleLike for Agent {
self.config.top_p = value; self.config.top_p = value;
} }
fn set_functions_filter(&mut self, _value: Option<FunctionsFilter>) {} fn set_use_tools(&mut self, _value: Option<String>) {}
} }
#[derive(Debug, Clone, Default, Deserialize, Serialize)] #[derive(Debug, Clone, Default, Deserialize, Serialize)]
@ -184,7 +182,7 @@ pub struct AgentConfig {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>, pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub dangerously_functions_filter: Option<FunctionsFilter>, pub dangerously_functions_filter: Option<String>,
} }
impl AgentConfig { impl AgentConfig {
@ -208,6 +206,8 @@ pub struct AgentDefinition {
pub conversation_starters: Vec<String>, pub conversation_starters: Vec<String>,
#[serde(default)] #[serde(default)]
pub documents: Vec<String>, pub documents: Vec<String>,
#[serde(default)]
pub common_tools: Vec<String>,
} }
impl AgentDefinition { impl AgentDefinition {

@ -12,13 +12,14 @@ use crate::client::{
create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig, create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig,
Model, OPENAI_COMPATIBLE_PLATFORMS, Model, OPENAI_COMPATIBLE_PLATFORMS,
}; };
use crate::function::{FunctionDeclaration, Functions, FunctionsFilter, ToolResult}; use crate::function::{FunctionDeclaration, Functions, ToolResult};
use crate::rag::Rag; use crate::rag::Rag;
use crate::render::{MarkdownRender, RenderOptions}; use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::*; use crate::utils::*;
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex; use fancy_regex::Regex;
use indexmap::IndexMap;
use inquire::{Confirm, Select}; use inquire::{Confirm, Select};
use parking_lot::RwLock; use parking_lot::RwLock;
use serde::Deserialize; use serde::Deserialize;
@ -103,7 +104,9 @@ pub struct Config {
pub summary_prompt: Option<String>, pub summary_prompt: Option<String>,
pub function_calling: bool, pub function_calling: bool,
pub dangerously_functions_filter: Option<FunctionsFilter>, pub mapping_tools: IndexMap<String, String>,
pub use_tools: Option<String>,
pub dangerously_functions_filter: Option<String>,
pub agents: Vec<AgentConfig>, pub agents: Vec<AgentConfig>,
pub rag_embedding_model: Option<String>, pub rag_embedding_model: Option<String>,
@ -164,6 +167,8 @@ impl Default for Config {
agent_prelude: None, agent_prelude: None,
function_calling: true, function_calling: true,
mapping_tools: Default::default(),
use_tools: None,
dangerously_functions_filter: None, dangerously_functions_filter: None,
agents: vec![], agents: vec![],
@ -421,6 +426,9 @@ impl Config {
if role.top_p().is_none() && self.top_p.is_some() { if role.top_p().is_none() && self.top_p.is_some() {
role.set_top_p(self.top_p); role.set_top_p(self.top_p);
} }
if role.use_tools().is_none() && self.use_tools.is_some() && self.agent.is_none() {
role.set_use_tools(self.use_tools.clone())
}
role role
} }
@ -475,6 +483,7 @@ impl Config {
("save_session", format_option_value(&self.save_session)), ("save_session", format_option_value(&self.save_session)),
("compress_threshold", self.compress_threshold.to_string()), ("compress_threshold", self.compress_threshold.to_string()),
("function_calling", self.function_calling.to_string()), ("function_calling", self.function_calling.to_string()),
("use_tools", format_option_value(&role.use_tools())),
( (
"rag_reranker_model", "rag_reranker_model",
format_option_value(&self.rag_reranker_model), format_option_value(&self.rag_reranker_model),
@ -545,6 +554,13 @@ impl Config {
} }
self.function_calling = value; self.function_calling = value;
} }
"use_tools" => {
if self.agent.is_some() {
bail!("This action cannot be performed within an agent.")
}
let value = parse_value(value)?;
self.set_use_tools(value);
}
"compress_threshold" => { "compress_threshold" => {
let value = parse_value(value)?; let value = parse_value(value)?;
self.set_compress_threshold(value); self.set_compress_threshold(value);
@ -584,6 +600,13 @@ impl Config {
} }
} }
pub fn set_use_tools(&mut self, value: Option<String>) {
match self.role_like_mut() {
Some(role_like) => role_like.set_use_tools(value),
None => self.use_tools = value,
}
}
pub fn set_save_session(&mut self, value: Option<bool>) { pub fn set_save_session(&mut self, value: Option<bool>) {
if let Some(session) = self.session.as_mut() { if let Some(session) = self.session.as_mut() {
session.set_save_session(value); session.set_save_session(value);
@ -1034,23 +1057,62 @@ impl Config {
} }
pub fn select_functions(&self, model: &Model, role: &Role) -> Option<Vec<FunctionDeclaration>> { pub fn select_functions(&self, model: &Model, role: &Role) -> Option<Vec<FunctionDeclaration>> {
let mut functions = None; let mut functions = vec![];
if self.function_calling { if self.function_calling {
let filter = role.functions_filter(); let use_tools = role.use_tools();
if let Some(filter) = filter { let declaration_names: HashSet<String> = self
functions = match &self.agent { .functions
Some(agent) => agent.functions().select(&filter), .declarations()
None => self.functions.select(&filter), .iter()
}; .map(|v| v.name.to_string())
.collect();
if let Some(use_tools) = use_tools {
let mut tool_names: HashSet<String> = Default::default();
for item in use_tools.split(',') {
let item = item.trim();
if item == "all" {
tool_names.extend(declaration_names);
break;
} else if let Some(values) = self.mapping_tools.get(item) {
tool_names.extend(
values
.split(',')
.map(|v| v.to_string())
.filter(|v| declaration_names.contains(v)),
)
} else if declaration_names.contains(item) {
tool_names.insert(item.to_string());
}
}
functions = self
.functions
.declarations()
.iter()
.filter_map(|v| {
if tool_names.contains(&v.name) {
Some(v.clone())
} else {
None
}
})
.collect();
if let Some(agent) = &self.agent {
let agent_functions = agent.functions().declarations().to_vec();
functions = [agent_functions, functions].concat();
}
if !model.supports_function_calling() { if !model.supports_function_calling() {
functions = None; functions.clear();
if *IS_STDOUT_TERMINAL { if *IS_STDOUT_TERMINAL {
eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it.")); eprintln!("{}", warning_text("WARNING: This LLM or client does not support function calling, despite the context requiring it."));
} }
} }
} }
}; };
functions if functions.is_empty() {
None
} else {
Some(functions)
}
} }
pub fn is_dangerously_function(&self, name: &str) -> bool { pub fn is_dangerously_function(&self, name: &str) -> bool {
@ -1110,14 +1172,15 @@ impl Config {
"max_output_tokens", "max_output_tokens",
"temperature", "temperature",
"top_p", "top_p",
"rag_reranker_model", "dry_run",
"rag_top_k",
"function_calling",
"compress_threshold",
"save", "save",
"save_session", "save_session",
"compress_threshold",
"function_calling",
"use_tools",
"rag_reranker_model",
"rag_top_k",
"highlight", "highlight",
"dry_run",
] ]
.into_iter() .into_iter()
.map(|v| (format!("{v} "), None)) .map(|v| (format!("{v} "), None))
@ -1131,8 +1194,7 @@ impl Config {
Some(v) => vec![v.to_string()], Some(v) => vec![v.to_string()],
None => vec![], None => vec![],
}, },
"rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(), "dry_run" => complete_bool(self.dry_run),
"function_calling" => complete_bool(self.function_calling),
"save" => complete_bool(self.save), "save" => complete_bool(self.save),
"save_session" => { "save_session" => {
let save_session = if let Some(session) = &self.session { let save_session = if let Some(session) = &self.session {
@ -1142,8 +1204,26 @@ impl Config {
}; };
complete_option_bool(save_session) complete_option_bool(save_session)
} }
"function_calling" => complete_bool(self.function_calling),
"use_tools" => {
let mut prefix = String::new();
if let Some((v, _)) = args[1].rsplit_once(',') {
prefix = format!("{v},");
}
let mut values = vec![];
if prefix.is_empty() {
values.push("all".to_string());
}
values.extend(self.mapping_tools.keys().map(|v| v.to_string()));
values.extend(self.functions.declarations().iter().map(|v| v.name.clone()));
values
.into_iter()
.filter(|v| !prefix.contains(&format!("{v},")))
.map(|v| format!("{prefix}{v}"))
.collect()
}
"rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(),
"highlight" => complete_bool(self.highlight), "highlight" => complete_bool(self.highlight),
"dry_run" => complete_bool(self.dry_run),
_ => vec![], _ => vec![],
}; };
(values.into_iter().map(|v| (v, None)).collect(), args[1]) (values.into_iter().map(|v| (v, None)).collect(), args[1])

@ -2,7 +2,6 @@ use super::*;
use crate::{ use crate::{
client::{Message, MessageContent, MessageRole, Model}, client::{Message, MessageContent, MessageRole, Model},
function::{FunctionsFilter, SELECTED_ALL_FUNCTIONS},
utils::{detect_os, detect_shell}, utils::{detect_os, detect_shell},
}; };
@ -21,11 +20,11 @@ pub trait RoleLike {
fn model_mut(&mut self) -> &mut Model; fn model_mut(&mut self) -> &mut Model;
fn temperature(&self) -> Option<f64>; fn temperature(&self) -> Option<f64>;
fn top_p(&self) -> Option<f64>; fn top_p(&self) -> Option<f64>;
fn functions_filter(&self) -> Option<FunctionsFilter>; fn use_tools(&self) -> Option<String>;
fn set_model(&mut self, model: &Model); fn set_model(&mut self, model: &Model);
fn set_temperature(&mut self, value: Option<f64>); fn set_temperature(&mut self, value: Option<f64>);
fn set_top_p(&mut self, value: Option<f64>); fn set_top_p(&mut self, value: Option<f64>);
fn set_functions_filter(&mut self, value: Option<FunctionsFilter>); fn set_use_tools(&mut self, value: Option<String>);
} }
#[derive(Debug, Clone, Default, Deserialize, Serialize)] #[derive(Debug, Clone, Default, Deserialize, Serialize)]
@ -43,7 +42,7 @@ pub struct Role {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>, top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
functions_filter: Option<FunctionsFilter>, use_tools: Option<String>,
#[serde(skip)] #[serde(skip)]
model: Model, model: Model,
@ -85,17 +84,13 @@ async function timeout(ms) {
.into(), .into(),
None, None,
), ),
( ("%functions%", String::new(), Some("all".into())),
"%functions%",
String::new(),
Some(SELECTED_ALL_FUNCTIONS.into()),
),
] ]
.into_iter() .into_iter()
.map(|(name, prompt, functions_filter)| Self { .map(|(name, prompt, use_tools)| Self {
name: name.into(), name: name.into(),
prompt, prompt,
functions_filter, use_tools,
..Default::default() ..Default::default()
}) })
.collect() .collect()
@ -111,8 +106,8 @@ async function timeout(ms) {
let model = role_like.model(); let model = role_like.model();
let temperature = role_like.temperature(); let temperature = role_like.temperature();
let top_p = role_like.top_p(); let top_p = role_like.top_p();
let functions_filter = role_like.functions_filter(); let use_tools = role_like.use_tools();
self.batch_set(model, temperature, top_p, functions_filter); self.batch_set(model, temperature, top_p, use_tools);
} }
pub fn batch_set( pub fn batch_set(
@ -120,7 +115,7 @@ async function timeout(ms) {
model: &Model, model: &Model,
temperature: Option<f64>, temperature: Option<f64>,
top_p: Option<f64>, top_p: Option<f64>,
functions_filter: Option<FunctionsFilter>, use_tools: Option<String>,
) { ) {
self.set_model(model); self.set_model(model);
if temperature.is_some() { if temperature.is_some() {
@ -129,8 +124,8 @@ async function timeout(ms) {
if top_p.is_some() { if top_p.is_some() {
self.set_top_p(top_p); self.set_top_p(top_p);
} }
if functions_filter.is_some() { if use_tools.is_some() {
self.set_functions_filter(functions_filter); self.set_use_tools(use_tools);
} }
} }
@ -242,8 +237,8 @@ impl RoleLike for Role {
self.top_p self.top_p
} }
fn functions_filter(&self) -> Option<FunctionsFilter> { fn use_tools(&self) -> Option<String> {
self.functions_filter.clone() self.use_tools.clone()
} }
fn set_model(&mut self, model: &Model) { fn set_model(&mut self, model: &Model) {
@ -258,8 +253,8 @@ impl RoleLike for Role {
self.top_p = value; self.top_p = value;
} }
fn set_functions_filter(&mut self, value: Option<FunctionsFilter>) { fn set_use_tools(&mut self, value: Option<String>) {
self.functions_filter = value; self.use_tools = value;
} }
} }

@ -21,7 +21,7 @@ pub struct Session {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>, top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
functions_filter: Option<FunctionsFilter>, use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
save_session: Option<bool>, save_session: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@ -134,8 +134,8 @@ impl Session {
if let Some(top_p) = self.top_p() { if let Some(top_p) = self.top_p() {
data["top_p"] = top_p.into(); data["top_p"] = top_p.into();
} }
if let Some(functions_filter) = self.functions_filter() { if let Some(use_tools) = self.use_tools() {
data["functions_filter"] = functions_filter.into(); data["use_tools"] = use_tools.into();
} }
if let Some(save_session) = self.save_session() { if let Some(save_session) = self.save_session() {
data["save_session"] = save_session.into(); data["save_session"] = save_session.into();
@ -171,8 +171,8 @@ impl Session {
items.push(("top_p", top_p.to_string())); items.push(("top_p", top_p.to_string()));
} }
if let Some(functions_filter) = self.functions_filter() { if let Some(use_tools) = self.use_tools() {
items.push(("functions_filter", functions_filter)); items.push(("use_tools", use_tools));
} }
if let Some(save_session) = self.save_session() { if let Some(save_session) = self.save_session() {
@ -242,7 +242,7 @@ impl Session {
self.model_id = role.model().id(); self.model_id = role.model().id();
self.temperature = role.temperature(); self.temperature = role.temperature();
self.top_p = role.top_p(); self.top_p = role.top_p();
self.functions_filter = role.functions_filter(); self.use_tools = role.use_tools();
self.model = role.model().clone(); self.model = role.model().clone();
self.role_name = role.name().to_string(); self.role_name = role.name().to_string();
self.role_prompt = role.prompt().to_string(); self.role_prompt = role.prompt().to_string();
@ -345,7 +345,7 @@ impl Session {
pub fn guard_empty(&self) -> Result<()> { pub fn guard_empty(&self) -> Result<()> {
if !self.is_empty() { if !self.is_empty() {
bail!("Cannot perform this action in a session with messages") bail!("This action cannot be performed in a session with messages.")
} }
Ok(()) Ok(())
} }
@ -442,8 +442,8 @@ impl RoleLike for Session {
self.top_p self.top_p
} }
fn functions_filter(&self) -> Option<FunctionsFilter> { fn use_tools(&self) -> Option<String> {
self.functions_filter.clone() self.use_tools.clone()
} }
fn set_model(&mut self, model: &Model) { fn set_model(&mut self, model: &Model) {
@ -468,9 +468,9 @@ impl RoleLike for Session {
} }
} }
fn set_functions_filter(&mut self, value: Option<FunctionsFilter>) { fn set_use_tools(&mut self, value: Option<String>) {
if self.functions_filter != value { if self.use_tools != value {
self.functions_filter = value; self.use_tools = value;
self.dirty = true; self.dirty = true;
} }
} }

@ -4,8 +4,7 @@ use crate::{
}; };
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex; use indexmap::IndexMap;
use indexmap::{IndexMap, IndexSet};
use inquire::{validator::Validation, Text}; use inquire::{validator::Validation, Text};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@ -15,9 +14,7 @@ use std::{
path::Path, path::Path,
}; };
pub const SELECTED_ALL_FUNCTIONS: &str = ".*";
pub type ToolResults = (Vec<ToolResult>, String); pub type ToolResults = (Vec<ToolResult>, String);
pub type FunctionsFilter = String;
pub fn eval_tool_calls(config: &GlobalConfig, mut calls: Vec<ToolCall>) -> Result<Vec<ToolResult>> { pub fn eval_tool_calls(config: &GlobalConfig, mut calls: Vec<ToolCall>) -> Result<Vec<ToolResult>> {
let mut output = vec![]; let mut output = vec![];
@ -53,7 +50,6 @@ impl ToolResult {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct Functions { pub struct Functions {
names: IndexSet<String>,
declarations: Vec<FunctionDeclaration>, declarations: Vec<FunctionDeclaration>,
} }
@ -72,35 +68,19 @@ impl Functions {
vec![] vec![]
}; };
let names = declarations.iter().map(|v| v.name.clone()).collect(); Ok(Self { declarations })
Ok(Self {
names,
declarations,
})
} }
pub fn select(&self, filter: &FunctionsFilter) -> Option<Vec<FunctionDeclaration>> { pub fn contains(&self, name: &str) -> bool {
let regex = Regex::new(&format!("^({filter})$")).ok()?; self.declarations.iter().any(|v| v.name == name)
let output: Vec<FunctionDeclaration> = self
.declarations
.iter()
.filter(|v| regex.is_match(&v.name).unwrap_or_default())
.cloned()
.collect();
if output.is_empty() {
None
} else {
Some(output)
}
} }
pub fn contains(&self, name: &str) -> bool { pub fn declarations(&self) -> &[FunctionDeclaration] {
self.names.contains(name) &self.declarations
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.names.is_empty() self.declarations.is_empty()
} }
} }
@ -174,18 +154,21 @@ impl ToolCall {
let is_dangerously = config.read().is_dangerously_function(&function_name); let is_dangerously = config.read().is_dangerously_function(&function_name);
let (call_name, cmd_name, mut cmd_args) = match &config.read().agent { let (call_name, cmd_name, mut cmd_args) = match &config.read().agent {
Some(agent) => { Some(agent) => {
if !agent.functions().contains(&function_name) { if agent.functions().contains(&function_name) {
(
format!("{}:{}", agent.name(), function_name),
agent.name().to_string(),
vec![function_name],
)
} else if config.read().functions.contains(&function_name) {
(function_name.clone(), function_name, vec![])
} else {
bail!( bail!(
"Unexpected call: {} {function_name} {}", "Unexpected call: {} {function_name} {}",
agent.name(), agent.name(),
self.arguments self.arguments
); );
} }
(
format!("{}:{}", agent.name(), function_name),
agent.name().to_string(),
vec![function_name],
)
} }
None => { None => {
if !config.read().functions.contains(&function_name) { if !config.read().functions.contains(&function_name) {

Loading…
Cancel
Save