mirror of
https://github.com/sigoden/aichat
synced 2024-11-16 06:15:26 +00:00
feat: custom REPL prompt (#283)
This commit is contained in:
parent
89fefb4d1a
commit
1c9ca1b002
@ -155,8 +155,9 @@ aichat has a powerful Chat REPL.
|
||||
The Chat REPL supports:
|
||||
|
||||
- Emacs/Vi keybinding
|
||||
- Command autocompletion
|
||||
- Edit/paste multiline input
|
||||
- [Custom REPL Prompt](https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt)
|
||||
- Tab Completion
|
||||
- Edit/paste multiline text
|
||||
- Undo support
|
||||
|
||||
### `.help` - print help message
|
||||
|
@ -9,6 +9,10 @@ auto_copy: false # Automatically copy the last output to the cli
|
||||
keybindings: emacs # REPL keybindings. (emacs, vi)
|
||||
prelude: '' # Set a default role or session (role:<name>, session:<name>)
|
||||
|
||||
# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
|
||||
left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} '
|
||||
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
|
||||
|
||||
clients:
|
||||
# All clients have the following configuration:
|
||||
# - type: xxxx
|
||||
@ -38,7 +42,7 @@ clients:
|
||||
# See https://github.com/jmorganca/ollama
|
||||
- type: ollama
|
||||
api_base: http://localhost:11434/api
|
||||
api_key: Baisc xxx
|
||||
api_key: Basic xxx # Set authorization header
|
||||
chat_endpoint: /chat # Optional field
|
||||
models:
|
||||
- name: gpt4all-j
|
||||
|
@ -246,7 +246,7 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool
|
||||
Ok((body, has_upload))
|
||||
}
|
||||
|
||||
/// Patch messsages, upload emebeded images to oss
|
||||
/// Patch messsages, upload embedded images to oss
|
||||
async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>) -> Result<()> {
|
||||
for message in messages {
|
||||
if let MessageContent::Array(list) = message.content.borrow_mut() {
|
||||
@ -258,7 +258,7 @@ async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>)
|
||||
if url.starts_with("data:") {
|
||||
*url = upload(model, api_key, url)
|
||||
.await
|
||||
.with_context(|| "Failed to upload embeded image to oss")?;
|
||||
.with_context(|| "Failed to upload embedded image to oss")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,13 +11,14 @@ use crate::client::{
|
||||
Model, OpenAIClient, SendData,
|
||||
};
|
||||
use crate::render::{MarkdownRender, RenderOptions};
|
||||
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
|
||||
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err, render_prompt};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use inquire::{Confirm, Select, Text};
|
||||
use is_terminal::IsTerminal;
|
||||
use parking_lot::RwLock;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::{
|
||||
env,
|
||||
fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions},
|
||||
@ -66,6 +67,10 @@ pub struct Config {
|
||||
pub keybindings: Keybindings,
|
||||
/// Set a default role or session (role:<name>, session:<name>)
|
||||
pub prelude: String,
|
||||
/// REPL left prompt
|
||||
pub left_prompt: String,
|
||||
/// REPL right prompt
|
||||
pub right_prompt: String,
|
||||
/// Setup clients
|
||||
pub clients: Vec<ClientConfig>,
|
||||
/// Predefined roles
|
||||
@ -99,6 +104,9 @@ impl Default for Config {
|
||||
auto_copy: false,
|
||||
keybindings: Default::default(),
|
||||
prelude: String::new(),
|
||||
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
|
||||
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
|
||||
.to_string(),
|
||||
clients: vec![ClientConfig::default()],
|
||||
roles: vec![],
|
||||
role: None,
|
||||
@ -648,18 +656,14 @@ impl Config {
|
||||
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
|
||||
}
|
||||
|
||||
pub fn render_prompt_left(&self) -> String {
|
||||
let variables = self.generate_prompt_context();
|
||||
render_prompt(&self.left_prompt, &variables)
|
||||
}
|
||||
|
||||
pub fn render_prompt_right(&self) -> String {
|
||||
if let Some(session) = &self.session {
|
||||
let (tokens, percent) = session.tokens_and_percent();
|
||||
let percent = if percent == 0.0 {
|
||||
String::new()
|
||||
} else {
|
||||
format!("({percent}%)")
|
||||
};
|
||||
format!("{tokens}{percent}")
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
let variables = self.generate_prompt_context();
|
||||
render_prompt(&self.right_prompt, &variables)
|
||||
}
|
||||
|
||||
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
|
||||
@ -681,6 +685,70 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_prompt_context(&self) -> HashMap<&str, String> {
|
||||
let mut output = HashMap::new();
|
||||
output.insert("model", self.model.id());
|
||||
output.insert("client_name", self.model.client_name.clone());
|
||||
output.insert("model_name", self.model.name.clone());
|
||||
output.insert(
|
||||
"max_tokens",
|
||||
self.model.max_tokens.unwrap_or_default().to_string(),
|
||||
);
|
||||
if let Some(temperature) = self.temperature {
|
||||
if temperature != 0.0 {
|
||||
output.insert("temperature", temperature.to_string());
|
||||
}
|
||||
}
|
||||
if self.dry_run {
|
||||
output.insert("dry_run", "true".to_string());
|
||||
}
|
||||
if self.save {
|
||||
output.insert("save", "true".to_string());
|
||||
}
|
||||
if let Some(wrap) = &self.wrap {
|
||||
if wrap != "no" {
|
||||
output.insert("wrap", wrap.clone());
|
||||
}
|
||||
}
|
||||
if self.auto_copy {
|
||||
output.insert("auto_copy", "true".to_string());
|
||||
}
|
||||
if let Some(role) = &self.role {
|
||||
output.insert("role", role.name.clone());
|
||||
}
|
||||
if let Some(session) = &self.session {
|
||||
output.insert("session", session.name().to_string());
|
||||
let (tokens, percent) = session.tokens_and_percent();
|
||||
output.insert("consume_tokens", tokens.to_string());
|
||||
output.insert("consume_percent", percent.to_string());
|
||||
output.insert("user_messages_len", session.user_messages_len().to_string());
|
||||
}
|
||||
|
||||
if self.highlight {
|
||||
output.insert("color.reset", "\u{1b}[0m".to_string());
|
||||
output.insert("color.black", "\u{1b}[30m".to_string());
|
||||
output.insert("color.dark_gray", "\u{1b}[90m".to_string());
|
||||
output.insert("color.red", "\u{1b}[31m".to_string());
|
||||
output.insert("color.light_red", "\u{1b}[91m".to_string());
|
||||
output.insert("color.green", "\u{1b}[32m".to_string());
|
||||
output.insert("color.light_green", "\u{1b}[92m".to_string());
|
||||
output.insert("color.yellow", "\u{1b}[33m".to_string());
|
||||
output.insert("color.light_yellow", "\u{1b}[93m".to_string());
|
||||
output.insert("color.blue", "\u{1b}[34m".to_string());
|
||||
output.insert("color.light_blue", "\u{1b}[94m".to_string());
|
||||
output.insert("color.purple", "\u{1b}[35m".to_string());
|
||||
output.insert("color.light_purple", "\u{1b}[95m".to_string());
|
||||
output.insert("color.magenta", "\u{1b}[35m".to_string());
|
||||
output.insert("color.light_magenta", "\u{1b}[95m".to_string());
|
||||
output.insert("color.cyan", "\u{1b}[36m".to_string());
|
||||
output.insert("color.light_cyan", "\u{1b}[96m".to_string());
|
||||
output.insert("color.white", "\u{1b}[37m".to_string());
|
||||
output.insert("color.light_gray", "\u{1b}[97m".to_string());
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn open_message_file(&self) -> Result<File> {
|
||||
let path = Self::messages_file()?;
|
||||
ensure_parent_exists(&path)?;
|
||||
|
@ -78,6 +78,10 @@ impl Session {
|
||||
self.model.total_tokens(&self.messages)
|
||||
}
|
||||
|
||||
pub fn user_messages_len(&self) -> usize {
|
||||
self.messages.iter().filter(|v| v.role.is_user()).count()
|
||||
}
|
||||
|
||||
pub fn export(&self) -> Result<String> {
|
||||
self.guard_save()?;
|
||||
let (tokens, percent) = self.tokens_and_percent();
|
||||
|
@ -1,14 +1,8 @@
|
||||
use crate::config::GlobalConfig;
|
||||
|
||||
use crossterm::style::Color;
|
||||
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
|
||||
use std::borrow::Cow;
|
||||
|
||||
const PROMPT_COLOR: Color = Color::Green;
|
||||
const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBlue;
|
||||
const INDICATOR_COLOR: Color = Color::Cyan;
|
||||
const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReplPrompt {
|
||||
config: GlobalConfig,
|
||||
@ -24,13 +18,7 @@ impl ReplPrompt {
|
||||
|
||||
impl Prompt for ReplPrompt {
|
||||
fn render_prompt_left(&self) -> Cow<str> {
|
||||
if let Some(session) = &self.config.read().session {
|
||||
Cow::Owned(session.name().to_string())
|
||||
} else if let Some(role) = &self.config.read().role {
|
||||
Cow::Owned(role.name.clone())
|
||||
} else {
|
||||
Cow::Borrowed("")
|
||||
}
|
||||
Cow::Owned(self.config.read().render_prompt_left())
|
||||
}
|
||||
|
||||
fn render_prompt_right(&self) -> Cow<str> {
|
||||
@ -38,11 +26,7 @@ impl Prompt for ReplPrompt {
|
||||
}
|
||||
|
||||
fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<str> {
|
||||
if self.config.read().session.is_some() {
|
||||
Cow::Borrowed(") ")
|
||||
} else {
|
||||
Cow::Borrowed("> ")
|
||||
}
|
||||
Cow::Borrowed("")
|
||||
}
|
||||
|
||||
fn render_prompt_multiline_indicator(&self) -> Cow<str> {
|
||||
@ -64,20 +48,4 @@ impl Prompt for ReplPrompt {
|
||||
prefix, history_search.term
|
||||
))
|
||||
}
|
||||
|
||||
fn get_prompt_color(&self) -> Color {
|
||||
PROMPT_COLOR
|
||||
}
|
||||
/// Get the default multiline prompt color
|
||||
fn get_prompt_multiline_color(&self) -> nu_ansi_term::Color {
|
||||
PROMPT_MULTILINE_COLOR
|
||||
}
|
||||
/// Get the default indicator color
|
||||
fn get_indicator_color(&self) -> Color {
|
||||
INDICATOR_COLOR
|
||||
}
|
||||
/// Get the default right prompt color
|
||||
fn get_prompt_right_color(&self) -> Color {
|
||||
PROMPT_RIGHT_COLOR
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
mod abort_signal;
|
||||
mod clipboard;
|
||||
mod prompt_input;
|
||||
mod render_prompt;
|
||||
mod tiktoken;
|
||||
|
||||
pub use self::abort_signal::{create_abort_signal, AbortSignal};
|
||||
pub use self::clipboard::set_text;
|
||||
pub use self::prompt_input::*;
|
||||
pub use self::render_prompt::render_prompt;
|
||||
pub use self::tiktoken::cl100k_base_singleton;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
155
src/utils/render_prompt.rs
Normal file
155
src/utils/render_prompt.rs
Normal file
@ -0,0 +1,155 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Render REPL prompt
|
||||
///
|
||||
/// The template comprises plain text and `{...}`.
|
||||
///
|
||||
/// The syntax of `{...}`:
|
||||
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
|
||||
/// - `{?var <template>}` - Eval `template` when `var` is evaluated as true
|
||||
/// - `{!var <template>}` - Eval `template` when `var` is evaluated as false
|
||||
pub fn render_prompt(template: &str, variables: &HashMap<&str, String>) -> String {
|
||||
let exprs = parse_template(template);
|
||||
eval_exprs(&exprs, variables)
|
||||
}
|
||||
|
||||
fn parse_template(template: &str) -> Vec<Expr> {
|
||||
let chars: Vec<char> = template.chars().collect();
|
||||
let mut exprs = vec![];
|
||||
let mut current = vec![];
|
||||
let mut balances = vec![];
|
||||
for ch in chars.iter().cloned() {
|
||||
if !balances.is_empty() {
|
||||
if ch == '}' {
|
||||
balances.pop();
|
||||
if balances.is_empty() {
|
||||
if !current.is_empty() {
|
||||
let block = parse_block(&mut current);
|
||||
exprs.push(block)
|
||||
}
|
||||
} else {
|
||||
current.push(ch);
|
||||
}
|
||||
} else if ch == '{' {
|
||||
balances.push(ch);
|
||||
current.push(ch);
|
||||
} else {
|
||||
current.push(ch);
|
||||
}
|
||||
} else if ch == '{' {
|
||||
balances.push(ch);
|
||||
add_text(&mut exprs, &mut current);
|
||||
} else {
|
||||
current.push(ch)
|
||||
}
|
||||
}
|
||||
add_text(&mut exprs, &mut current);
|
||||
exprs
|
||||
}
|
||||
|
||||
fn parse_block(current: &mut Vec<char>) -> Expr {
|
||||
let value: String = current.drain(..).collect();
|
||||
match value.split_once(' ') {
|
||||
Some((name, tail)) => {
|
||||
if let Some(name) = name.strip_prefix('?') {
|
||||
let block_exprs = parse_template(tail);
|
||||
Expr::Block(BlockType::Yes, name.to_string(), block_exprs)
|
||||
} else if let Some(name) = name.strip_prefix('!') {
|
||||
let block_exprs = parse_template(tail);
|
||||
Expr::Block(BlockType::No, name.to_string(), block_exprs)
|
||||
} else {
|
||||
Expr::Text(format!("{{{value}}}"))
|
||||
}
|
||||
}
|
||||
None => Expr::Variable(value),
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_exprs(exprs: &[Expr], variables: &HashMap<&str, String>) -> String {
|
||||
let mut output = String::new();
|
||||
for part in exprs {
|
||||
match part {
|
||||
Expr::Text(text) => output.push_str(text),
|
||||
Expr::Variable(variable) => {
|
||||
let value = variables
|
||||
.get(variable.as_str())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
output.push_str(&value);
|
||||
}
|
||||
Expr::Block(typ, variable, block_exprs) => {
|
||||
let value = variables
|
||||
.get(variable.as_str())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
match typ {
|
||||
BlockType::Yes => {
|
||||
if truly(&value) {
|
||||
let block_output = eval_exprs(block_exprs, variables);
|
||||
output.push_str(&block_output)
|
||||
}
|
||||
}
|
||||
BlockType::No => {
|
||||
if !truly(&value) {
|
||||
let block_output = eval_exprs(block_exprs, variables);
|
||||
output.push_str(&block_output)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
fn add_text(exprs: &mut Vec<Expr>, current: &mut Vec<char>) {
|
||||
if current.is_empty() {
|
||||
return;
|
||||
}
|
||||
let value: String = current.drain(..).collect();
|
||||
exprs.push(Expr::Text(value));
|
||||
}
|
||||
|
||||
fn truly(value: &str) -> bool {
|
||||
!(value.is_empty() || value == "0" || value == "false")
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Expr {
|
||||
Text(String),
|
||||
Variable(String),
|
||||
Block(BlockType, String, Vec<Expr>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BlockType {
|
||||
Yes,
|
||||
No,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
macro_rules! assert_render {
|
||||
($template:expr, [$(($key:literal, $value:literal),)*], $expect:literal) => {
|
||||
let data = HashMap::from([
|
||||
$(($key, $value.into()),)*
|
||||
]);
|
||||
assert_eq!(render_prompt($template, &data), $expect);
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_render() {
|
||||
let prompt = "{?session {session}{?role /}}{role}{?session )}{!session >}";
|
||||
assert_render!(prompt, [], ">");
|
||||
assert_render!(prompt, [("role", "coder"),], "coder>");
|
||||
assert_render!(prompt, [("session", "temp"),], "temp)");
|
||||
assert_render!(
|
||||
prompt,
|
||||
[("session", "temp"), ("role", "coder"),],
|
||||
"temp/coder)"
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user