From 1e8fc5d269985048d8d3023a615b94a8908571cf Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 11 May 2024 09:23:59 +0800 Subject: [PATCH] refactor: list roles includeing builtin roles (#499) --- src/config/input.rs | 17 +++++--- src/config/mod.rs | 17 ++++---- src/config/role.rs | 100 ++++++++++++++++++++------------------------ src/main.rs | 11 ++--- src/repl/mod.rs | 3 +- 5 files changed, 72 insertions(+), 76 deletions(-) diff --git a/src/config/input.rs b/src/config/input.rs index b7b896c..0527b49 100644 --- a/src/config/input.rs +++ b/src/config/input.rs @@ -106,7 +106,7 @@ impl Input { } pub fn session<'a>(&self, session: &'a Option) -> Option<&'a Session> { - if self.context.in_session { + if self.context.session { session.as_ref() } else { None @@ -114,7 +114,7 @@ impl Input { } pub fn session_mut<'a>(&self, session: &'a mut Option) -> Option<&'a mut Session> { - if self.context.in_session { + if self.context.session { session.as_mut() } else { None @@ -199,12 +199,19 @@ impl Input { #[derive(Debug, Clone, Default)] pub struct InputContext { role: Option, - in_session: bool, + session: bool, } impl InputContext { - pub fn new(role: Option, in_session: bool) -> Self { - Self { role, in_session } + pub fn new(role: Option, session: bool) -> Self { + Self { role, session } + } + + pub fn role(role: Role) -> Self { + Self { + role: Some(role), + session: false, + } } } diff --git a/src/config/mod.rs b/src/config/mod.rs index e166969..4a1867f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,7 +3,7 @@ mod role; mod session; pub use self::input::{Input, InputContext}; -pub use self::role::{Role, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE}; +pub use self::role::{Role, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE}; use self::session::{Session, TEMP_SESSION_NAME}; use crate::client::{ @@ -190,7 +190,6 @@ impl Config { role.complete_prompt_args(name); role }) - .or_else(|| Role::find_system_role(name)) .ok_or_else(|| anyhow!("Unknown role `{name}`")) } @@ -682,10 +681,6 @@ impl Config { Ok(()) } - pub fn has_session(&self) -> bool { - self.session.is_some() - } - pub fn clear_session_messages(&mut self) -> Result<()> { if let Some(session) = self.session.as_mut() { session.clear_messages(); @@ -818,7 +813,7 @@ impl Config { } pub fn input_context(&self) -> InputContext { - InputContext::new(self.role.clone(), self.has_session()) + InputContext::new(self.role.clone(), self.session.is_some()) } pub fn maybe_print_send_tokens(&self, input: &Input) { @@ -978,7 +973,15 @@ impl Config { .with_context(|| format!("Failed to load roles at {}", path.display()))?; let roles: Vec = serde_yaml::from_str(&content).with_context(|| "Invalid roles config")?; + + let exist_roles: HashSet<_> = roles.iter().map(|v| v.name.clone()).collect(); self.roles = roles; + let builtin_roles = Role::builtin(); + for role in builtin_roles { + if !exist_roles.contains(&role.name) { + self.roles.push(role); + } + } Ok(()) } diff --git a/src/config/role.rs b/src/config/role.rs index 39b777c..4fc34c9 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; pub const TEMP_ROLE: &str = "%%"; pub const SHELL_ROLE: &str = "%shell%"; -pub const EXPLAIN_ROLE: &str = "%explain%"; +pub const EXPLAIN_SHELL_ROLE: &str = "%explain-shell%"; pub const CODE_ROLE: &str = "%code%"; pub const INPUT_PLACEHOLDER: &str = "__INPUT__"; @@ -32,61 +32,20 @@ impl Role { } } - pub fn find_system_role(name: &str) -> Option { - match name { - SHELL_ROLE => Some(Self::shell()), - EXPLAIN_ROLE => Some(Self::explain()), - CODE_ROLE => Some(Self::code()), - _ => None, - } - } - - pub fn shell() -> Self { - let os = detect_os(); - let (detected_shell, _, _) = detect_shell(); - let (shell, use_semicolon) = match (detected_shell.as_str(), os.as_str()) { - // GPT doesn’t know much about nushell - ("nushell", "windows") => ("cmd", true), - ("nushell", _) => ("bash", true), - ("powershell", _) => ("powershell", true), - ("pwsh", _) => ("powershell", false), - _ => (detected_shell.as_str(), false), - }; - let combine = if use_semicolon { - "\nIf multiple steps required try to combine them together using ';'.\nIf it already combined with '&&' try to replace it with ';'.".to_string() - } else { - "\nIf multiple steps required try to combine them together using &&.".to_string() - }; - Self { - name: SHELL_ROLE.into(), - prompt: format!( - r#"Provide only {shell} commands for {os} without any description. -Ensure the output is a valid {shell} command. {combine} -If there is a lack of details, provide most logical solution. -Output plain text only, without any markdown formatting."# - ), - temperature: None, - top_p: None, - } - } - - pub fn explain() -> Self { - Self { - name: EXPLAIN_ROLE.into(), - prompt: r#"Provide a terse, single sentence description of the given shell command. + pub fn builtin() -> Vec { + [ + (SHELL_ROLE, shell_prompt()), + ( + EXPLAIN_SHELL_ROLE, + r#"Provide a terse, single sentence description of the given shell command. Describe each argument and option of the command. Provide short responses in about 80 words. APPLY MARKDOWN formatting when possible."# - .into(), - temperature: None, - top_p: None, - } - } - - pub fn code() -> Self { - Self { - name: CODE_ROLE.into(), - prompt: r#"Provide only code without comments or explanations. + .into(), + ), + ( + CODE_ROLE, + r#"Provide only code without comments or explanations. ### INPUT: async sleep in js ### OUTPUT: @@ -96,10 +55,17 @@ async function timeout(ms) { } ``` "# - .into(), + .into(), + ), + ] + .into_iter() + .map(|(name, prompt)| Self { + name: name.into(), + prompt, temperature: None, top_p: None, - } + }) + .collect() } pub fn export(&self) -> Result { @@ -241,6 +207,30 @@ fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) { (prompt, vec![]) } +fn shell_prompt() -> String { + let os = detect_os(); + let (detected_shell, _, _) = detect_shell(); + let (shell, use_semicolon) = match (detected_shell.as_str(), os.as_str()) { + // GPT doesn’t know much about nushell + ("nushell", "windows") => ("cmd", true), + ("nushell", _) => ("bash", true), + ("powershell", _) => ("powershell", true), + ("pwsh", _) => ("powershell", false), + _ => (detected_shell.as_str(), false), + }; + let combine = if use_semicolon { + "\nIf multiple steps required try to combine them together using ';'.\nIf it already combined with '&&' try to replace it with ';'.".to_string() + } else { + "\nIf multiple steps required try to combine them together using '&&'.".to_string() + }; + format!( + r#"Provide only {shell} commands for {os} without any description. +Ensure the output is a valid {shell} command. {combine} +If there is a lack of details, provide most logical solution. +Output plain text only, without any markdown formatting."# + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/main.rs b/src/main.rs index 397ab1e..558e473 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,8 @@ extern crate log; use crate::cli::Cli; use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream}; use crate::config::{ - Config, GlobalConfig, Input, WorkingMode, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE, + Config, GlobalConfig, Input, InputContext, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE, + SHELL_ROLE, }; use crate::render::{render_error, MarkdownRender}; use crate::repl::Repl; @@ -200,7 +201,6 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> { return Ok(()); } if is_terminal_stdout { - let mut explain = false; loop { let answer = Select::new( markdown_render.render(&eval_str).trim(), @@ -222,13 +222,10 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> { return execute(config, input).await; } "📙 Explain" => { - if !explain { - config.write().set_role(EXPLAIN_ROLE)?; - } - let input = Input::from_str(&eval_str, config.read().input_context()); + let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?; + let input = Input::from_str(&eval_str, InputContext::role(role)); let abort = create_abort_signal(); send_stream(&input, client.as_ref(), config, abort).await?; - explain = true; continue; } _ => {} diff --git a/src/repl/mod.rs b/src/repl/mod.rs index d7cd9d1..db0c9d2 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -176,8 +176,7 @@ impl Repl { Some(args) => match args.split_once(|c| c == '\n' || c == ' ') { Some((name, text)) => { let role = self.config.read().retrieve_role(name.trim())?; - let input = - Input::from_str(text.trim(), InputContext::new(Some(role), false)); + let input = Input::from_str(text.trim(), InputContext::role(role)); self.ask(input).await?; } None => {