diff --git a/README.md b/README.md index 49686db..a96e004 100644 --- a/README.md +++ b/README.md @@ -155,8 +155,9 @@ aichat has a powerful Chat REPL. The Chat REPL supports: - Emacs/Vi keybinding -- Command autocompletion -- Edit/paste multiline input +- [Custom REPL Prompt](https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt) +- Tab Completion +- Edit/paste multiline text - Undo support ### `.help` - print help message diff --git a/config.example.yaml b/config.example.yaml index 874a53f..975c0e1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -9,6 +9,10 @@ auto_copy: false # Automatically copy the last output to the cli keybindings: emacs # REPL keybindings. (emacs, vi) prelude: '' # Set a default role or session (role:, session:) +# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt +left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ' +right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}' + clients: # All clients have the following configuration: # - type: xxxx @@ -38,7 +42,7 @@ clients: # See https://github.com/jmorganca/ollama - type: ollama api_base: http://localhost:11434/api - api_key: Baisc xxx + api_key: Basic xxx # Set authorization header chat_endpoint: /chat # Optional field models: - name: gpt4all-j diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 022e286..5b8755f 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -246,7 +246,7 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool Ok((body, has_upload)) } -/// Patch messsages, upload emebeded images to oss +/// Patch messsages, upload embedded images to oss async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec) -> Result<()> { for message in messages { if let MessageContent::Array(list) = message.content.borrow_mut() { @@ -258,7 +258,7 @@ async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec) if url.starts_with("data:") { *url = upload(model, api_key, url) .await - .with_context(|| "Failed to upload embeded image to oss")?; + .with_context(|| "Failed to upload embedded image to oss")?; } } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 52203ba..71e47da 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,13 +11,14 @@ use crate::client::{ Model, OpenAIClient, SendData, }; use crate::render::{MarkdownRender, RenderOptions}; -use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err}; +use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err, render_prompt}; use anyhow::{anyhow, bail, Context, Result}; use inquire::{Confirm, Select, Text}; use is_terminal::IsTerminal; use parking_lot::RwLock; use serde::Deserialize; +use std::collections::HashMap; use std::{ env, fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions}, @@ -66,6 +67,10 @@ pub struct Config { pub keybindings: Keybindings, /// Set a default role or session (role:, session:) pub prelude: String, + /// REPL left prompt + pub left_prompt: String, + /// REPL right prompt + pub right_prompt: String, /// Setup clients pub clients: Vec, /// Predefined roles @@ -99,6 +104,9 @@ impl Default for Config { auto_copy: false, keybindings: Default::default(), prelude: String::new(), + left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(), + right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}" + .to_string(), clients: vec![ClientConfig::default()], roles: vec![], role: None, @@ -648,18 +656,14 @@ impl Config { Ok(RenderOptions::new(theme, wrap, self.wrap_code)) } + pub fn render_prompt_left(&self) -> String { + let variables = self.generate_prompt_context(); + render_prompt(&self.left_prompt, &variables) + } + pub fn render_prompt_right(&self) -> String { - if let Some(session) = &self.session { - let (tokens, percent) = session.tokens_and_percent(); - let percent = if percent == 0.0 { - String::new() - } else { - format!("({percent}%)") - }; - format!("{tokens}{percent}") - } else { - String::new() - } + let variables = self.generate_prompt_context(); + render_prompt(&self.right_prompt, &variables) } pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result { @@ -681,6 +685,70 @@ impl Config { } } + fn generate_prompt_context(&self) -> HashMap<&str, String> { + let mut output = HashMap::new(); + output.insert("model", self.model.id()); + output.insert("client_name", self.model.client_name.clone()); + output.insert("model_name", self.model.name.clone()); + output.insert( + "max_tokens", + self.model.max_tokens.unwrap_or_default().to_string(), + ); + if let Some(temperature) = self.temperature { + if temperature != 0.0 { + output.insert("temperature", temperature.to_string()); + } + } + if self.dry_run { + output.insert("dry_run", "true".to_string()); + } + if self.save { + output.insert("save", "true".to_string()); + } + if let Some(wrap) = &self.wrap { + if wrap != "no" { + output.insert("wrap", wrap.clone()); + } + } + if self.auto_copy { + output.insert("auto_copy", "true".to_string()); + } + if let Some(role) = &self.role { + output.insert("role", role.name.clone()); + } + if let Some(session) = &self.session { + output.insert("session", session.name().to_string()); + let (tokens, percent) = session.tokens_and_percent(); + output.insert("consume_tokens", tokens.to_string()); + output.insert("consume_percent", percent.to_string()); + output.insert("user_messages_len", session.user_messages_len().to_string()); + } + + if self.highlight { + output.insert("color.reset", "\u{1b}[0m".to_string()); + output.insert("color.black", "\u{1b}[30m".to_string()); + output.insert("color.dark_gray", "\u{1b}[90m".to_string()); + output.insert("color.red", "\u{1b}[31m".to_string()); + output.insert("color.light_red", "\u{1b}[91m".to_string()); + output.insert("color.green", "\u{1b}[32m".to_string()); + output.insert("color.light_green", "\u{1b}[92m".to_string()); + output.insert("color.yellow", "\u{1b}[33m".to_string()); + output.insert("color.light_yellow", "\u{1b}[93m".to_string()); + output.insert("color.blue", "\u{1b}[34m".to_string()); + output.insert("color.light_blue", "\u{1b}[94m".to_string()); + output.insert("color.purple", "\u{1b}[35m".to_string()); + output.insert("color.light_purple", "\u{1b}[95m".to_string()); + output.insert("color.magenta", "\u{1b}[35m".to_string()); + output.insert("color.light_magenta", "\u{1b}[95m".to_string()); + output.insert("color.cyan", "\u{1b}[36m".to_string()); + output.insert("color.light_cyan", "\u{1b}[96m".to_string()); + output.insert("color.white", "\u{1b}[37m".to_string()); + output.insert("color.light_gray", "\u{1b}[97m".to_string()); + } + + output + } + fn open_message_file(&self) -> Result { let path = Self::messages_file()?; ensure_parent_exists(&path)?; diff --git a/src/config/session.rs b/src/config/session.rs index e135c58..cbfe1c2 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -78,6 +78,10 @@ impl Session { self.model.total_tokens(&self.messages) } + pub fn user_messages_len(&self) -> usize { + self.messages.iter().filter(|v| v.role.is_user()).count() + } + pub fn export(&self) -> Result { self.guard_save()?; let (tokens, percent) = self.tokens_and_percent(); diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index 30600af..7b52106 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -1,14 +1,8 @@ use crate::config::GlobalConfig; -use crossterm::style::Color; use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus}; use std::borrow::Cow; -const PROMPT_COLOR: Color = Color::Green; -const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBlue; -const INDICATOR_COLOR: Color = Color::Cyan; -const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5); - #[derive(Clone)] pub struct ReplPrompt { config: GlobalConfig, @@ -24,13 +18,7 @@ impl ReplPrompt { impl Prompt for ReplPrompt { fn render_prompt_left(&self) -> Cow { - if let Some(session) = &self.config.read().session { - Cow::Owned(session.name().to_string()) - } else if let Some(role) = &self.config.read().role { - Cow::Owned(role.name.clone()) - } else { - Cow::Borrowed("") - } + Cow::Owned(self.config.read().render_prompt_left()) } fn render_prompt_right(&self) -> Cow { @@ -38,11 +26,7 @@ impl Prompt for ReplPrompt { } fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow { - if self.config.read().session.is_some() { - Cow::Borrowed(") ") - } else { - Cow::Borrowed("> ") - } + Cow::Borrowed("") } fn render_prompt_multiline_indicator(&self) -> Cow { @@ -64,20 +48,4 @@ impl Prompt for ReplPrompt { prefix, history_search.term )) } - - fn get_prompt_color(&self) -> Color { - PROMPT_COLOR - } - /// Get the default multiline prompt color - fn get_prompt_multiline_color(&self) -> nu_ansi_term::Color { - PROMPT_MULTILINE_COLOR - } - /// Get the default indicator color - fn get_indicator_color(&self) -> Color { - INDICATOR_COLOR - } - /// Get the default right prompt color - fn get_prompt_right_color(&self) -> Color { - PROMPT_RIGHT_COLOR - } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index bd9f381..04c5668 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,11 +1,13 @@ mod abort_signal; mod clipboard; mod prompt_input; +mod render_prompt; mod tiktoken; pub use self::abort_signal::{create_abort_signal, AbortSignal}; pub use self::clipboard::set_text; pub use self::prompt_input::*; +pub use self::render_prompt::render_prompt; pub use self::tiktoken::cl100k_base_singleton; use sha2::{Digest, Sha256}; diff --git a/src/utils/render_prompt.rs b/src/utils/render_prompt.rs new file mode 100644 index 0000000..12661fa --- /dev/null +++ b/src/utils/render_prompt.rs @@ -0,0 +1,155 @@ +use std::collections::HashMap; + +/// Render REPL prompt +/// +/// The template comprises plain text and `{...}`. +/// +/// The syntax of `{...}`: +/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template` +/// - `{?var