feat: custom REPL prompt (#283)

This commit is contained in:
sigoden 2023-12-24 16:04:18 +08:00 committed by GitHub
parent 89fefb4d1a
commit 1c9ca1b002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 51 deletions

View File

@ -155,8 +155,9 @@ aichat has a powerful Chat REPL.
The Chat REPL supports:
- Emacs/Vi keybinding
- Command autocompletion
- Edit/paste multiline input
- [Custom REPL Prompt](https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt)
- Tab Completion
- Edit/paste multiline text
- Undo support
### `.help` - print help message

View File

@ -9,6 +9,10 @@ auto_copy: false # Automatically copy the last output to the cli
keybindings: emacs # REPL keybindings. (emacs, vi)
prelude: '' # Set a default role or session (role:<name>, session:<name>)
# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} '
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
clients:
# All clients have the following configuration:
# - type: xxxx
@ -38,7 +42,7 @@ clients:
# See https://github.com/jmorganca/ollama
- type: ollama
api_base: http://localhost:11434/api
api_key: Baisc xxx
api_key: Basic xxx # Set authorization header
chat_endpoint: /chat # Optional field
models:
- name: gpt4all-j

View File

@ -246,7 +246,7 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool
Ok((body, has_upload))
}
/// Patch messsages, upload emebeded images to oss
/// Patch messsages, upload embedded images to oss
async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>) -> Result<()> {
for message in messages {
if let MessageContent::Array(list) = message.content.borrow_mut() {
@ -258,7 +258,7 @@ async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>)
if url.starts_with("data:") {
*url = upload(model, api_key, url)
.await
.with_context(|| "Failed to upload embeded image to oss")?;
.with_context(|| "Failed to upload embedded image to oss")?;
}
}
}

View File

@ -11,13 +11,14 @@ use crate::client::{
Model, OpenAIClient, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err, render_prompt};
use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use serde::Deserialize;
use std::collections::HashMap;
use std::{
env,
fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions},
@ -66,6 +67,10 @@ pub struct Config {
pub keybindings: Keybindings,
/// Set a default role or session (role:<name>, session:<name>)
pub prelude: String,
/// REPL left prompt
pub left_prompt: String,
/// REPL right prompt
pub right_prompt: String,
/// Setup clients
pub clients: Vec<ClientConfig>,
/// Predefined roles
@ -99,6 +104,9 @@ impl Default for Config {
auto_copy: false,
keybindings: Default::default(),
prelude: String::new(),
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
.to_string(),
clients: vec![ClientConfig::default()],
roles: vec![],
role: None,
@ -648,18 +656,14 @@ impl Config {
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
}
pub fn render_prompt_left(&self) -> String {
let variables = self.generate_prompt_context();
render_prompt(&self.left_prompt, &variables)
}
pub fn render_prompt_right(&self) -> String {
if let Some(session) = &self.session {
let (tokens, percent) = session.tokens_and_percent();
let percent = if percent == 0.0 {
String::new()
} else {
format!("({percent}%)")
};
format!("{tokens}{percent}")
} else {
String::new()
}
let variables = self.generate_prompt_context();
render_prompt(&self.right_prompt, &variables)
}
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
@ -681,6 +685,70 @@ impl Config {
}
}
fn generate_prompt_context(&self) -> HashMap<&str, String> {
let mut output = HashMap::new();
output.insert("model", self.model.id());
output.insert("client_name", self.model.client_name.clone());
output.insert("model_name", self.model.name.clone());
output.insert(
"max_tokens",
self.model.max_tokens.unwrap_or_default().to_string(),
);
if let Some(temperature) = self.temperature {
if temperature != 0.0 {
output.insert("temperature", temperature.to_string());
}
}
if self.dry_run {
output.insert("dry_run", "true".to_string());
}
if self.save {
output.insert("save", "true".to_string());
}
if let Some(wrap) = &self.wrap {
if wrap != "no" {
output.insert("wrap", wrap.clone());
}
}
if self.auto_copy {
output.insert("auto_copy", "true".to_string());
}
if let Some(role) = &self.role {
output.insert("role", role.name.clone());
}
if let Some(session) = &self.session {
output.insert("session", session.name().to_string());
let (tokens, percent) = session.tokens_and_percent();
output.insert("consume_tokens", tokens.to_string());
output.insert("consume_percent", percent.to_string());
output.insert("user_messages_len", session.user_messages_len().to_string());
}
if self.highlight {
output.insert("color.reset", "\u{1b}[0m".to_string());
output.insert("color.black", "\u{1b}[30m".to_string());
output.insert("color.dark_gray", "\u{1b}[90m".to_string());
output.insert("color.red", "\u{1b}[31m".to_string());
output.insert("color.light_red", "\u{1b}[91m".to_string());
output.insert("color.green", "\u{1b}[32m".to_string());
output.insert("color.light_green", "\u{1b}[92m".to_string());
output.insert("color.yellow", "\u{1b}[33m".to_string());
output.insert("color.light_yellow", "\u{1b}[93m".to_string());
output.insert("color.blue", "\u{1b}[34m".to_string());
output.insert("color.light_blue", "\u{1b}[94m".to_string());
output.insert("color.purple", "\u{1b}[35m".to_string());
output.insert("color.light_purple", "\u{1b}[95m".to_string());
output.insert("color.magenta", "\u{1b}[35m".to_string());
output.insert("color.light_magenta", "\u{1b}[95m".to_string());
output.insert("color.cyan", "\u{1b}[36m".to_string());
output.insert("color.light_cyan", "\u{1b}[96m".to_string());
output.insert("color.white", "\u{1b}[37m".to_string());
output.insert("color.light_gray", "\u{1b}[97m".to_string());
}
output
}
fn open_message_file(&self) -> Result<File> {
let path = Self::messages_file()?;
ensure_parent_exists(&path)?;

View File

@ -78,6 +78,10 @@ impl Session {
self.model.total_tokens(&self.messages)
}
pub fn user_messages_len(&self) -> usize {
self.messages.iter().filter(|v| v.role.is_user()).count()
}
pub fn export(&self) -> Result<String> {
self.guard_save()?;
let (tokens, percent) = self.tokens_and_percent();

View File

@ -1,14 +1,8 @@
use crate::config::GlobalConfig;
use crossterm::style::Color;
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
use std::borrow::Cow;
const PROMPT_COLOR: Color = Color::Green;
const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBlue;
const INDICATOR_COLOR: Color = Color::Cyan;
const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5);
#[derive(Clone)]
pub struct ReplPrompt {
config: GlobalConfig,
@ -24,13 +18,7 @@ impl ReplPrompt {
impl Prompt for ReplPrompt {
fn render_prompt_left(&self) -> Cow<str> {
if let Some(session) = &self.config.read().session {
Cow::Owned(session.name().to_string())
} else if let Some(role) = &self.config.read().role {
Cow::Owned(role.name.clone())
} else {
Cow::Borrowed("")
}
Cow::Owned(self.config.read().render_prompt_left())
}
fn render_prompt_right(&self) -> Cow<str> {
@ -38,11 +26,7 @@ impl Prompt for ReplPrompt {
}
fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<str> {
if self.config.read().session.is_some() {
Cow::Borrowed(") ")
} else {
Cow::Borrowed("> ")
}
Cow::Borrowed("")
}
fn render_prompt_multiline_indicator(&self) -> Cow<str> {
@ -64,20 +48,4 @@ impl Prompt for ReplPrompt {
prefix, history_search.term
))
}
fn get_prompt_color(&self) -> Color {
PROMPT_COLOR
}
/// Get the default multiline prompt color
fn get_prompt_multiline_color(&self) -> nu_ansi_term::Color {
PROMPT_MULTILINE_COLOR
}
/// Get the default indicator color
fn get_indicator_color(&self) -> Color {
INDICATOR_COLOR
}
/// Get the default right prompt color
fn get_prompt_right_color(&self) -> Color {
PROMPT_RIGHT_COLOR
}
}

View File

@ -1,11 +1,13 @@
mod abort_signal;
mod clipboard;
mod prompt_input;
mod render_prompt;
mod tiktoken;
pub use self::abort_signal::{create_abort_signal, AbortSignal};
pub use self::clipboard::set_text;
pub use self::prompt_input::*;
pub use self::render_prompt::render_prompt;
pub use self::tiktoken::cl100k_base_singleton;
use sha2::{Digest, Sha256};

155
src/utils/render_prompt.rs Normal file
View File

@ -0,0 +1,155 @@
use std::collections::HashMap;
/// Render REPL prompt
///
/// The template comprises plain text and `{...}`.
///
/// The syntax of `{...}`:
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
/// - `{?var <template>}` - Eval `template` when `var` is evaluated as true
/// - `{!var <template>}` - Eval `template` when `var` is evaluated as false
pub fn render_prompt(template: &str, variables: &HashMap<&str, String>) -> String {
let exprs = parse_template(template);
eval_exprs(&exprs, variables)
}
fn parse_template(template: &str) -> Vec<Expr> {
let chars: Vec<char> = template.chars().collect();
let mut exprs = vec![];
let mut current = vec![];
let mut balances = vec![];
for ch in chars.iter().cloned() {
if !balances.is_empty() {
if ch == '}' {
balances.pop();
if balances.is_empty() {
if !current.is_empty() {
let block = parse_block(&mut current);
exprs.push(block)
}
} else {
current.push(ch);
}
} else if ch == '{' {
balances.push(ch);
current.push(ch);
} else {
current.push(ch);
}
} else if ch == '{' {
balances.push(ch);
add_text(&mut exprs, &mut current);
} else {
current.push(ch)
}
}
add_text(&mut exprs, &mut current);
exprs
}
fn parse_block(current: &mut Vec<char>) -> Expr {
let value: String = current.drain(..).collect();
match value.split_once(' ') {
Some((name, tail)) => {
if let Some(name) = name.strip_prefix('?') {
let block_exprs = parse_template(tail);
Expr::Block(BlockType::Yes, name.to_string(), block_exprs)
} else if let Some(name) = name.strip_prefix('!') {
let block_exprs = parse_template(tail);
Expr::Block(BlockType::No, name.to_string(), block_exprs)
} else {
Expr::Text(format!("{{{value}}}"))
}
}
None => Expr::Variable(value),
}
}
fn eval_exprs(exprs: &[Expr], variables: &HashMap<&str, String>) -> String {
let mut output = String::new();
for part in exprs {
match part {
Expr::Text(text) => output.push_str(text),
Expr::Variable(variable) => {
let value = variables
.get(variable.as_str())
.cloned()
.unwrap_or_default();
output.push_str(&value);
}
Expr::Block(typ, variable, block_exprs) => {
let value = variables
.get(variable.as_str())
.cloned()
.unwrap_or_default();
match typ {
BlockType::Yes => {
if truly(&value) {
let block_output = eval_exprs(block_exprs, variables);
output.push_str(&block_output)
}
}
BlockType::No => {
if !truly(&value) {
let block_output = eval_exprs(block_exprs, variables);
output.push_str(&block_output)
}
}
}
}
}
}
output
}
fn add_text(exprs: &mut Vec<Expr>, current: &mut Vec<char>) {
if current.is_empty() {
return;
}
let value: String = current.drain(..).collect();
exprs.push(Expr::Text(value));
}
fn truly(value: &str) -> bool {
!(value.is_empty() || value == "0" || value == "false")
}
#[derive(Debug)]
enum Expr {
Text(String),
Variable(String),
Block(BlockType, String, Vec<Expr>),
}
#[derive(Debug)]
enum BlockType {
Yes,
No,
}
#[cfg(test)]
mod tests {
use super::*;
macro_rules! assert_render {
($template:expr, [$(($key:literal, $value:literal),)*], $expect:literal) => {
let data = HashMap::from([
$(($key, $value.into()),)*
]);
assert_eq!(render_prompt($template, &data), $expect);
};
}
#[test]
fn test_render() {
let prompt = "{?session {session}{?role /}}{role}{?session )}{!session >}";
assert_render!(prompt, [], ">");
assert_render!(prompt, [("role", "coder"),], "coder>");
assert_render!(prompt, [("session", "temp"),], "temp)");
assert_render!(
prompt,
[("session", "temp"), ("role", "coder"),],
"temp/coder)"
);
}
}