feat: support agent variables (#692)

pull/693/head
sigoden 2 months ago committed by GitHub
parent 138c90b58b
commit a9268b600f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,7 +3,11 @@ use super::*;
use crate::{client::Model, function::Functions};
use anyhow::{Context, Result};
use std::{fs::read_to_string, path::Path};
use inquire::{validator::Validation, Text};
use std::{
fs::{self, read_to_string},
path::Path,
};
use serde::{Deserialize, Serialize};
@ -29,8 +33,13 @@ impl Agent {
let functions_dir = Config::agent_functions_dir(name)?;
let definition_file_path = functions_dir.join("index.yaml");
let functions_file_path = functions_dir.join("functions.json");
let variables_path = Config::agent_variables_file(name)?;
let rag_path = Config::agent_rag_file(name)?;
let definition = AgentDefinition::load(&definition_file_path)?;
let mut definition = AgentDefinition::load(&definition_file_path)?;
init_variables(&variables_path, &mut definition.variables)
.context("Failed to init variables")?;
let functions = if functions_file_path.exists() {
Functions::init(&functions_file_path)?
} else {
@ -50,6 +59,7 @@ impl Agent {
None => config.current_model().clone(),
}
};
let rag = if rag_path.exists() {
Some(Arc::new(Rag::load(config, "rag", &rag_path)?))
} else if !definition.documents.is_empty() {
@ -91,6 +101,10 @@ impl Agent {
.display()
.to_string()
.into();
value["variables_file"] = Config::agent_variables_file(&self.name)?
.display()
.to_string()
.into();
let data = serde_yaml::to_string(&value)?;
Ok(data)
}
@ -122,11 +136,28 @@ impl Agent {
pub fn conversation_staters(&self) -> &[String] {
&self.definition.conversation_starters
}
pub fn variables(&self) -> &[AgentVariable] {
&self.definition.variables
}
pub fn set_variable(&mut self, key: &str, value: &str) -> Result<()> {
match self.definition.variables.iter_mut().find(|v| v.name == key) {
Some(variable) => {
variable.value = value.to_string();
let variables_path = Config::agent_variables_file(&self.name)?;
save_variables(&variables_path, self.variables())?;
Ok(())
}
None => bail!("Unknown variable '{key}'"),
}
}
}
impl RoleLike for Agent {
fn to_role(&self) -> Role {
let mut role = Role::new("", &self.definition.instructions);
let prompt = self.definition.interpolated_instructions();
let mut role = Role::new("", &prompt);
role.sync(self);
role
}
@ -202,6 +233,8 @@ pub struct AgentDefinition {
pub version: String,
pub instructions: String,
#[serde(default)]
pub variables: Vec<AgentVariable>,
#[serde(default)]
pub conversation_starters: Vec<String>,
#[serde(default)]
pub documents: Vec<String>,
@ -244,6 +277,24 @@ impl AgentDefinition {
{description}{starters}"#
)
}
fn interpolated_instructions(&self) -> String {
let mut output = self.instructions.clone();
for variable in &self.variables {
output = output.replace(&format!("{{{{{}}}}}", variable.name), &variable.value)
}
output
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentVariable {
pub name: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
#[serde(skip_deserializing, default)]
pub value: String,
}
pub fn list_agents() -> Vec<String> {
@ -266,3 +317,66 @@ fn list_agents_impl() -> Result<Vec<String>> {
.collect();
Ok(agents)
}
fn init_variables(variables_path: &Path, variables: &mut [AgentVariable]) -> Result<()> {
if variables.is_empty() {
return Ok(());
}
let variable_values = if variables_path.exists() {
let content = read_to_string(variables_path).with_context(|| {
format!(
"Failed to read variables from '{}'",
variables_path.display()
)
})?;
let variable_values: IndexMap<String, String> = serde_yaml::from_str(&content)?;
variable_values
} else {
Default::default()
};
let mut initialized = false;
for variable in variables.iter_mut() {
match variable_values.get(&variable.name) {
Some(value) => variable.value = value.to_string(),
None => {
if !initialized {
println!("The agent has the variables and is initializing them...");
initialized = true;
}
if *IS_STDOUT_TERMINAL {
let mut text =
Text::new(&variable.description).with_validator(|input: &str| {
if input.trim().is_empty() {
Ok(Validation::Invalid("This field is required".into()))
} else {
Ok(Validation::Valid)
}
});
if let Some(default) = &variable.default {
text = text.with_default(default);
}
let value = text.prompt()?;
variable.value = value;
} else {
bail!("Failed to init agent variables in the script mode.");
}
}
}
}
if initialized {
save_variables(variables_path, variables)?;
}
Ok(())
}
fn save_variables(variables_path: &Path, variables: &[AgentVariable]) -> Result<()> {
ensure_parent_exists(variables_path)?;
let variable_values: IndexMap<String, String> = variables
.iter()
.map(|v| (v.name.clone(), v.value.clone()))
.collect();
let content = serde_yaml::to_string(&variable_values)?;
fs::write(variables_path, content)
.with_context(|| format!("Failed to save variables to '{}'", variables_path.display()))?;
Ok(())
}

@ -50,6 +50,7 @@ const FUNCTIONS_FILE_NAME: &str = "functions.json";
const FUNCTIONS_BIN_DIR_NAME: &str = "bin";
const AGENTS_DIR_NAME: &str = "agents";
const AGENT_RAG_FILE_NAME: &str = "rag.bin";
const AGENT_VARIABLES_FILE_NAME: &str = "variables.yaml";
pub const TEMP_ROLE_NAME: &str = "%%";
pub const TEMP_RAG_NAME: &str = "temp";
@ -352,6 +353,10 @@ impl Config {
Ok(Self::agent_config_dir(name)?.join(AGENT_RAG_FILE_NAME))
}
pub fn agent_variables_file(name: &str) -> Result<PathBuf> {
Ok(Self::agent_config_dir(name)?.join(AGENT_VARIABLES_FILE_NAME))
}
pub fn agents_functions_dir() -> Result<PathBuf> {
match env::var(get_env_name("agents_functions_dir")) {
Ok(value) => Ok(PathBuf::from(value)),
@ -1014,6 +1019,20 @@ impl Config {
}
}
pub fn set_agent_variable(&mut self, data: &str) -> Result<()> {
let parts: Vec<&str> = data.split_whitespace().collect();
if parts.len() != 2 {
bail!("Usage: .variable <key> <value>");
}
let key = parts[0];
let value = parts[1];
match self.agent.as_mut() {
Some(agent) => agent.set_variable(key, value)?,
None => bail!("No agent"),
};
Ok(())
}
pub fn exit_agent(&mut self) -> Result<()> {
self.exit_session()?;
if self.agent.take().is_some() {
@ -1180,6 +1199,14 @@ impl Config {
.collect(),
None => vec![],
},
".variable" => match &self.agent {
Some(agent) => agent
.variables()
.iter()
.map(|v| (v.name.clone(), Some(v.description.clone())))
.collect(),
None => vec![],
},
".set" => vec![
"max_output_tokens",
"temperature",
@ -1200,7 +1227,7 @@ impl Config {
_ => vec![],
};
(values, args[0])
} else if args.len() == 2 {
} else if args.len() == 2 && cmd == ".set" {
let values = match args[0] {
"max_output_tokens" => match self.model.max_output_tokens() {
Some(v) => vec![v.to_string()],
@ -1693,11 +1720,11 @@ pub(crate) fn ensure_parent_exists(path: &Path) -> Result<()> {
}
let parent = path
.parent()
.ok_or_else(|| anyhow!("Failed to write to {}, No parent path", path.display()))?;
.ok_or_else(|| anyhow!("Failed to write to '{}', No parent path", path.display()))?;
if !parent.exists() {
create_dir_all(parent).with_context(|| {
format!(
"Failed to write {}, Cannot create parent directory",
"Failed to write to '{}', Cannot create parent directory",
path.display()
)
})?;

@ -9,7 +9,7 @@ use inquire::{validator::Validation, Confirm, Text};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{self, create_dir_all, read_to_string};
use std::fs::{self, read_to_string};
use std::path::Path;
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
@ -314,13 +314,7 @@ impl Session {
}
pub fn save(&mut self, session_path: &Path, is_repl: bool) -> Result<()> {
if let Some(sessions_dir) = session_path.parent() {
if !sessions_dir.exists() {
create_dir_all(sessions_dir).with_context(|| {
format!("Failed to create session_dir '{}'", sessions_dir.display())
})?;
}
}
ensure_parent_exists(session_path)?;
self.path = Some(session_path.display().to_string());

@ -33,7 +33,7 @@ lazy_static! {
const MENU_NAME: &str = "completion_menu";
lazy_static! {
static ref REPL_COMMANDS: [ReplCommand; 27] = [
static ref REPL_COMMANDS: [ReplCommand; 28] = [
ReplCommand::new(".help", "Show this help message", AssertState::pass()),
ReplCommand::new(".info", "View system info", AssertState::pass()),
ReplCommand::new(".model", "Change the current LLM", AssertState::pass()),
@ -118,6 +118,11 @@ lazy_static! {
"Use the conversation starter",
AssertState::True(StateFlags::AGENT)
),
ReplCommand::new(
".variable",
"Set agent variable",
AssertState::True(StateFlags::AGENT)
),
ReplCommand::new(
".exit agent",
"Leave the agent",
@ -293,6 +298,14 @@ Tips: use <tab> to autocomplete conversation starter text.
println!("{output}");
}
},
".variable" => match args {
Some(args) => {
self.config.write().set_agent_variable(args)?;
}
_ => {
println!("Usage: .variable <key> <value>")
}
},
".save" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),

Loading…
Cancel
Save