From a9268b600fb378400795fbaf98bf923372b4c19a Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 9 Jul 2024 21:22:51 +0800 Subject: [PATCH] feat: support agent variables (#692) --- src/config/agent.rs | 120 ++++++++++++++++++++++++++++++++++++++++-- src/config/mod.rs | 33 ++++++++++-- src/config/session.rs | 10 +--- src/repl/mod.rs | 15 +++++- 4 files changed, 163 insertions(+), 15 deletions(-) diff --git a/src/config/agent.rs b/src/config/agent.rs index 82f887c..7a065e7 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -3,7 +3,11 @@ use super::*; use crate::{client::Model, function::Functions}; use anyhow::{Context, Result}; -use std::{fs::read_to_string, path::Path}; +use inquire::{validator::Validation, Text}; +use std::{ + fs::{self, read_to_string}, + path::Path, +}; use serde::{Deserialize, Serialize}; @@ -29,8 +33,13 @@ impl Agent { let functions_dir = Config::agent_functions_dir(name)?; let definition_file_path = functions_dir.join("index.yaml"); let functions_file_path = functions_dir.join("functions.json"); + let variables_path = Config::agent_variables_file(name)?; let rag_path = Config::agent_rag_file(name)?; - let definition = AgentDefinition::load(&definition_file_path)?; + + let mut definition = AgentDefinition::load(&definition_file_path)?; + init_variables(&variables_path, &mut definition.variables) + .context("Failed to init variables")?; + let functions = if functions_file_path.exists() { Functions::init(&functions_file_path)? } else { @@ -50,6 +59,7 @@ impl Agent { None => config.current_model().clone(), } }; + let rag = if rag_path.exists() { Some(Arc::new(Rag::load(config, "rag", &rag_path)?)) } else if !definition.documents.is_empty() { @@ -91,6 +101,10 @@ impl Agent { .display() .to_string() .into(); + value["variables_file"] = Config::agent_variables_file(&self.name)? + .display() + .to_string() + .into(); let data = serde_yaml::to_string(&value)?; Ok(data) } @@ -122,11 +136,28 @@ impl Agent { pub fn conversation_staters(&self) -> &[String] { &self.definition.conversation_starters } + + pub fn variables(&self) -> &[AgentVariable] { + &self.definition.variables + } + + pub fn set_variable(&mut self, key: &str, value: &str) -> Result<()> { + match self.definition.variables.iter_mut().find(|v| v.name == key) { + Some(variable) => { + variable.value = value.to_string(); + let variables_path = Config::agent_variables_file(&self.name)?; + save_variables(&variables_path, self.variables())?; + Ok(()) + } + None => bail!("Unknown variable '{key}'"), + } + } } impl RoleLike for Agent { fn to_role(&self) -> Role { - let mut role = Role::new("", &self.definition.instructions); + let prompt = self.definition.interpolated_instructions(); + let mut role = Role::new("", &prompt); role.sync(self); role } @@ -202,6 +233,8 @@ pub struct AgentDefinition { pub version: String, pub instructions: String, #[serde(default)] + pub variables: Vec, + #[serde(default)] pub conversation_starters: Vec, #[serde(default)] pub documents: Vec, @@ -244,6 +277,24 @@ impl AgentDefinition { {description}{starters}"# ) } + + fn interpolated_instructions(&self) -> String { + let mut output = self.instructions.clone(); + for variable in &self.variables { + output = output.replace(&format!("{{{{{}}}}}", variable.name), &variable.value) + } + output + } +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AgentVariable { + pub name: String, + pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_deserializing, default)] + pub value: String, } pub fn list_agents() -> Vec { @@ -266,3 +317,66 @@ fn list_agents_impl() -> Result> { .collect(); Ok(agents) } + +fn init_variables(variables_path: &Path, variables: &mut [AgentVariable]) -> Result<()> { + if variables.is_empty() { + return Ok(()); + } + let variable_values = if variables_path.exists() { + let content = read_to_string(variables_path).with_context(|| { + format!( + "Failed to read variables from '{}'", + variables_path.display() + ) + })?; + let variable_values: IndexMap = serde_yaml::from_str(&content)?; + variable_values + } else { + Default::default() + }; + let mut initialized = false; + for variable in variables.iter_mut() { + match variable_values.get(&variable.name) { + Some(value) => variable.value = value.to_string(), + None => { + if !initialized { + println!("The agent has the variables and is initializing them..."); + initialized = true; + } + if *IS_STDOUT_TERMINAL { + let mut text = + Text::new(&variable.description).with_validator(|input: &str| { + if input.trim().is_empty() { + Ok(Validation::Invalid("This field is required".into())) + } else { + Ok(Validation::Valid) + } + }); + if let Some(default) = &variable.default { + text = text.with_default(default); + } + let value = text.prompt()?; + variable.value = value; + } else { + bail!("Failed to init agent variables in the script mode."); + } + } + } + } + if initialized { + save_variables(variables_path, variables)?; + } + Ok(()) +} + +fn save_variables(variables_path: &Path, variables: &[AgentVariable]) -> Result<()> { + ensure_parent_exists(variables_path)?; + let variable_values: IndexMap = variables + .iter() + .map(|v| (v.name.clone(), v.value.clone())) + .collect(); + let content = serde_yaml::to_string(&variable_values)?; + fs::write(variables_path, content) + .with_context(|| format!("Failed to save variables to '{}'", variables_path.display()))?; + Ok(()) +} diff --git a/src/config/mod.rs b/src/config/mod.rs index a243a83..154dcd9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -50,6 +50,7 @@ const FUNCTIONS_FILE_NAME: &str = "functions.json"; const FUNCTIONS_BIN_DIR_NAME: &str = "bin"; const AGENTS_DIR_NAME: &str = "agents"; const AGENT_RAG_FILE_NAME: &str = "rag.bin"; +const AGENT_VARIABLES_FILE_NAME: &str = "variables.yaml"; pub const TEMP_ROLE_NAME: &str = "%%"; pub const TEMP_RAG_NAME: &str = "temp"; @@ -352,6 +353,10 @@ impl Config { Ok(Self::agent_config_dir(name)?.join(AGENT_RAG_FILE_NAME)) } + pub fn agent_variables_file(name: &str) -> Result { + Ok(Self::agent_config_dir(name)?.join(AGENT_VARIABLES_FILE_NAME)) + } + pub fn agents_functions_dir() -> Result { match env::var(get_env_name("agents_functions_dir")) { Ok(value) => Ok(PathBuf::from(value)), @@ -1014,6 +1019,20 @@ impl Config { } } + pub fn set_agent_variable(&mut self, data: &str) -> Result<()> { + let parts: Vec<&str> = data.split_whitespace().collect(); + if parts.len() != 2 { + bail!("Usage: .variable "); + } + let key = parts[0]; + let value = parts[1]; + match self.agent.as_mut() { + Some(agent) => agent.set_variable(key, value)?, + None => bail!("No agent"), + }; + Ok(()) + } + pub fn exit_agent(&mut self) -> Result<()> { self.exit_session()?; if self.agent.take().is_some() { @@ -1180,6 +1199,14 @@ impl Config { .collect(), None => vec![], }, + ".variable" => match &self.agent { + Some(agent) => agent + .variables() + .iter() + .map(|v| (v.name.clone(), Some(v.description.clone()))) + .collect(), + None => vec![], + }, ".set" => vec![ "max_output_tokens", "temperature", @@ -1200,7 +1227,7 @@ impl Config { _ => vec![], }; (values, args[0]) - } else if args.len() == 2 { + } else if args.len() == 2 && cmd == ".set" { let values = match args[0] { "max_output_tokens" => match self.model.max_output_tokens() { Some(v) => vec![v.to_string()], @@ -1693,11 +1720,11 @@ pub(crate) fn ensure_parent_exists(path: &Path) -> Result<()> { } let parent = path .parent() - .ok_or_else(|| anyhow!("Failed to write to {}, No parent path", path.display()))?; + .ok_or_else(|| anyhow!("Failed to write to '{}', No parent path", path.display()))?; if !parent.exists() { create_dir_all(parent).with_context(|| { format!( - "Failed to write {}, Cannot create parent directory", + "Failed to write to '{}', Cannot create parent directory", path.display() ) })?; diff --git a/src/config/session.rs b/src/config/session.rs index 80629c1..792b22d 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -9,7 +9,7 @@ use inquire::{validator::Validation, Confirm, Text}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; -use std::fs::{self, create_dir_all, read_to_string}; +use std::fs::{self, read_to_string}; use std::path::Path; #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -314,13 +314,7 @@ impl Session { } pub fn save(&mut self, session_path: &Path, is_repl: bool) -> Result<()> { - if let Some(sessions_dir) = session_path.parent() { - if !sessions_dir.exists() { - create_dir_all(sessions_dir).with_context(|| { - format!("Failed to create session_dir '{}'", sessions_dir.display()) - })?; - } - } + ensure_parent_exists(session_path)?; self.path = Some(session_path.display().to_string()); diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 2111c48..0100cb0 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -33,7 +33,7 @@ lazy_static! { const MENU_NAME: &str = "completion_menu"; lazy_static! { - static ref REPL_COMMANDS: [ReplCommand; 27] = [ + static ref REPL_COMMANDS: [ReplCommand; 28] = [ ReplCommand::new(".help", "Show this help message", AssertState::pass()), ReplCommand::new(".info", "View system info", AssertState::pass()), ReplCommand::new(".model", "Change the current LLM", AssertState::pass()), @@ -118,6 +118,11 @@ lazy_static! { "Use the conversation starter", AssertState::True(StateFlags::AGENT) ), + ReplCommand::new( + ".variable", + "Set agent variable", + AssertState::True(StateFlags::AGENT) + ), ReplCommand::new( ".exit agent", "Leave the agent", @@ -293,6 +298,14 @@ Tips: use to autocomplete conversation starter text. println!("{output}"); } }, + ".variable" => match args { + Some(args) => { + self.config.write().set_agent_variable(args)?; + } + _ => { + println!("Usage: .variable ") + } + }, ".save" => { match args.map(|v| match v.split_once(' ') { Some((subcmd, args)) => (subcmd, Some(args.trim())),