refactor: list roles includeing builtin roles (#499)

pull/500/head
sigoden 1 month ago committed by GitHub
parent 058299e500
commit 1e8fc5d269
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -106,7 +106,7 @@ impl Input {
} }
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> { pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
if self.context.in_session { if self.context.session {
session.as_ref() session.as_ref()
} else { } else {
None None
@ -114,7 +114,7 @@ impl Input {
} }
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> { pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
if self.context.in_session { if self.context.session {
session.as_mut() session.as_mut()
} else { } else {
None None
@ -199,12 +199,19 @@ impl Input {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct InputContext { pub struct InputContext {
role: Option<Role>, role: Option<Role>,
in_session: bool, session: bool,
} }
impl InputContext { impl InputContext {
pub fn new(role: Option<Role>, in_session: bool) -> Self { pub fn new(role: Option<Role>, session: bool) -> Self {
Self { role, in_session } Self { role, session }
}
pub fn role(role: Role) -> Self {
Self {
role: Some(role),
session: false,
}
} }
} }

@ -3,7 +3,7 @@ mod role;
mod session; mod session;
pub use self::input::{Input, InputContext}; pub use self::input::{Input, InputContext};
pub use self::role::{Role, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE}; pub use self::role::{Role, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE};
use self::session::{Session, TEMP_SESSION_NAME}; use self::session::{Session, TEMP_SESSION_NAME};
use crate::client::{ use crate::client::{
@ -190,7 +190,6 @@ impl Config {
role.complete_prompt_args(name); role.complete_prompt_args(name);
role role
}) })
.or_else(|| Role::find_system_role(name))
.ok_or_else(|| anyhow!("Unknown role `{name}`")) .ok_or_else(|| anyhow!("Unknown role `{name}`"))
} }
@ -682,10 +681,6 @@ impl Config {
Ok(()) Ok(())
} }
pub fn has_session(&self) -> bool {
self.session.is_some()
}
pub fn clear_session_messages(&mut self) -> Result<()> { pub fn clear_session_messages(&mut self) -> Result<()> {
if let Some(session) = self.session.as_mut() { if let Some(session) = self.session.as_mut() {
session.clear_messages(); session.clear_messages();
@ -818,7 +813,7 @@ impl Config {
} }
pub fn input_context(&self) -> InputContext { pub fn input_context(&self) -> InputContext {
InputContext::new(self.role.clone(), self.has_session()) InputContext::new(self.role.clone(), self.session.is_some())
} }
pub fn maybe_print_send_tokens(&self, input: &Input) { pub fn maybe_print_send_tokens(&self, input: &Input) {
@ -978,7 +973,15 @@ impl Config {
.with_context(|| format!("Failed to load roles at {}", path.display()))?; .with_context(|| format!("Failed to load roles at {}", path.display()))?;
let roles: Vec<Role> = let roles: Vec<Role> =
serde_yaml::from_str(&content).with_context(|| "Invalid roles config")?; serde_yaml::from_str(&content).with_context(|| "Invalid roles config")?;
let exist_roles: HashSet<_> = roles.iter().map(|v| v.name.clone()).collect();
self.roles = roles; self.roles = roles;
let builtin_roles = Role::builtin();
for role in builtin_roles {
if !exist_roles.contains(&role.name) {
self.roles.push(role);
}
}
Ok(()) Ok(())
} }

@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
pub const TEMP_ROLE: &str = "%%"; pub const TEMP_ROLE: &str = "%%";
pub const SHELL_ROLE: &str = "%shell%"; pub const SHELL_ROLE: &str = "%shell%";
pub const EXPLAIN_ROLE: &str = "%explain%"; pub const EXPLAIN_SHELL_ROLE: &str = "%explain-shell%";
pub const CODE_ROLE: &str = "%code%"; pub const CODE_ROLE: &str = "%code%";
pub const INPUT_PLACEHOLDER: &str = "__INPUT__"; pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
@ -32,61 +32,20 @@ impl Role {
} }
} }
pub fn find_system_role(name: &str) -> Option<Self> { pub fn builtin() -> Vec<Role> {
match name { [
SHELL_ROLE => Some(Self::shell()), (SHELL_ROLE, shell_prompt()),
EXPLAIN_ROLE => Some(Self::explain()), (
CODE_ROLE => Some(Self::code()), EXPLAIN_SHELL_ROLE,
_ => None, r#"Provide a terse, single sentence description of the given shell command.
}
}
pub fn shell() -> Self {
let os = detect_os();
let (detected_shell, _, _) = detect_shell();
let (shell, use_semicolon) = match (detected_shell.as_str(), os.as_str()) {
// GPT doesnt know much about nushell
("nushell", "windows") => ("cmd", true),
("nushell", _) => ("bash", true),
("powershell", _) => ("powershell", true),
("pwsh", _) => ("powershell", false),
_ => (detected_shell.as_str(), false),
};
let combine = if use_semicolon {
"\nIf multiple steps required try to combine them together using ';'.\nIf it already combined with '&&' try to replace it with ';'.".to_string()
} else {
"\nIf multiple steps required try to combine them together using &&.".to_string()
};
Self {
name: SHELL_ROLE.into(),
prompt: format!(
r#"Provide only {shell} commands for {os} without any description.
Ensure the output is a valid {shell} command. {combine}
If there is a lack of details, provide most logical solution.
Output plain text only, without any markdown formatting."#
),
temperature: None,
top_p: None,
}
}
pub fn explain() -> Self {
Self {
name: EXPLAIN_ROLE.into(),
prompt: r#"Provide a terse, single sentence description of the given shell command.
Describe each argument and option of the command. Describe each argument and option of the command.
Provide short responses in about 80 words. Provide short responses in about 80 words.
APPLY MARKDOWN formatting when possible."# APPLY MARKDOWN formatting when possible."#
.into(), .into(),
temperature: None, ),
top_p: None, (
} CODE_ROLE,
} r#"Provide only code without comments or explanations.
pub fn code() -> Self {
Self {
name: CODE_ROLE.into(),
prompt: r#"Provide only code without comments or explanations.
### INPUT: ### INPUT:
async sleep in js async sleep in js
### OUTPUT: ### OUTPUT:
@ -96,10 +55,17 @@ async function timeout(ms) {
} }
``` ```
"# "#
.into(), .into(),
),
]
.into_iter()
.map(|(name, prompt)| Self {
name: name.into(),
prompt,
temperature: None, temperature: None,
top_p: None, top_p: None,
} })
.collect()
} }
pub fn export(&self) -> Result<String> { pub fn export(&self) -> Result<String> {
@ -241,6 +207,30 @@ fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
(prompt, vec![]) (prompt, vec![])
} }
fn shell_prompt() -> String {
let os = detect_os();
let (detected_shell, _, _) = detect_shell();
let (shell, use_semicolon) = match (detected_shell.as_str(), os.as_str()) {
// GPT doesnt know much about nushell
("nushell", "windows") => ("cmd", true),
("nushell", _) => ("bash", true),
("powershell", _) => ("powershell", true),
("pwsh", _) => ("powershell", false),
_ => (detected_shell.as_str(), false),
};
let combine = if use_semicolon {
"\nIf multiple steps required try to combine them together using ';'.\nIf it already combined with '&&' try to replace it with ';'.".to_string()
} else {
"\nIf multiple steps required try to combine them together using '&&'.".to_string()
};
format!(
r#"Provide only {shell} commands for {os} without any description.
Ensure the output is a valid {shell} command. {combine}
If there is a lack of details, provide most logical solution.
Output plain text only, without any markdown formatting."#
)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

@ -14,7 +14,8 @@ extern crate log;
use crate::cli::Cli; use crate::cli::Cli;
use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream}; use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream};
use crate::config::{ use crate::config::{
Config, GlobalConfig, Input, WorkingMode, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE, Config, GlobalConfig, Input, InputContext, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE,
SHELL_ROLE,
}; };
use crate::render::{render_error, MarkdownRender}; use crate::render::{render_error, MarkdownRender};
use crate::repl::Repl; use crate::repl::Repl;
@ -200,7 +201,6 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
return Ok(()); return Ok(());
} }
if is_terminal_stdout { if is_terminal_stdout {
let mut explain = false;
loop { loop {
let answer = Select::new( let answer = Select::new(
markdown_render.render(&eval_str).trim(), markdown_render.render(&eval_str).trim(),
@ -222,13 +222,10 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
return execute(config, input).await; return execute(config, input).await;
} }
"📙 Explain" => { "📙 Explain" => {
if !explain { let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
config.write().set_role(EXPLAIN_ROLE)?; let input = Input::from_str(&eval_str, InputContext::role(role));
}
let input = Input::from_str(&eval_str, config.read().input_context());
let abort = create_abort_signal(); let abort = create_abort_signal();
send_stream(&input, client.as_ref(), config, abort).await?; send_stream(&input, client.as_ref(), config, abort).await?;
explain = true;
continue; continue;
} }
_ => {} _ => {}

@ -176,8 +176,7 @@ impl Repl {
Some(args) => match args.split_once(|c| c == '\n' || c == ' ') { Some(args) => match args.split_once(|c| c == '\n' || c == ' ') {
Some((name, text)) => { Some((name, text)) => {
let role = self.config.read().retrieve_role(name.trim())?; let role = self.config.read().retrieve_role(name.trim())?;
let input = let input = Input::from_str(text.trim(), InputContext::role(role));
Input::from_str(text.trim(), InputContext::new(Some(role), false));
self.ask(input).await?; self.ask(input).await?;
} }
None => { None => {

Loading…
Cancel
Save