From 3f693ea060d96b0adc397768c3b2dd47708a20ce Mon Sep 17 00:00:00 2001 From: sigoden Date: Mon, 4 Mar 2024 11:08:59 +0800 Subject: [PATCH] feat: compress session automaticlly (#333) * feat: compress session automaticlly * non-block * update field description * set compress_threshold * update session::clear_messages * able to override session compress_threshold * enable compress_threshold by default * make session compress_threshold optional --- README.md | 2 ++ config.example.yaml | 7 +++++ src/client/message.rs | 2 +- src/config/mod.rs | 66 +++++++++++++++++++++++++++++++++++++------ src/config/role.rs | 2 +- src/config/session.rs | 57 ++++++++++++++++++++++++++++--------- src/repl/mod.rs | 20 +++++++++++++ 7 files changed, 133 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 501a34c..2d353f5 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,7 @@ wrap_code: false # Whether wrap code block auto_copy: false # Automatically copy the last output to the clipboard keybindings: emacs # REPL keybindings. values: emacs, vi prelude: '' # Set a default role or session (role:, session:) +compress_threshold: 1000 # Compress session if tokens exceed this value (valid when >=1000) clients: - type: openai @@ -296,6 +297,7 @@ Usage: .file ... [-- text...] > .set highlight false > .set save false > .set auto_copy true +> .set compress_threshold 1000 ``` ## Command diff --git a/config.example.yaml b/config.example.yaml index e510734..69220f1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -9,6 +9,13 @@ auto_copy: false # Automatically copy the last output to the cli keybindings: emacs # REPL keybindings. (emacs, vi) prelude: '' # Set a default role or session (role:, session:) +# Compress session if tokens exceed this value (valid when >=1000) +compress_threshold: 1000 +# The prompt for summarizing session messages +summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use as a prompt for future context.' +# The prompt for the summary of the session +summary_prompt: 'This is a summary of the chat history as a recap: ' + # Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ' right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}' diff --git a/src/client/message.rs b/src/client/message.rs index dc8c3e1..757b673 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -17,7 +17,7 @@ impl Message { } } -#[derive(Debug, Clone, Copy, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum MessageRole { System, diff --git a/src/config/mod.rs b/src/config/mod.rs index 804a2fd..4a37378 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -67,6 +67,12 @@ pub struct Config { pub keybindings: Keybindings, /// Set a default role or session (role:, session:) pub prelude: String, + /// Compress session if tokens exceed this value (>=1000) + pub compress_threshold: usize, + /// The prompt for summarizing session messages + pub summarize_prompt: String, + // The prompt for the summary of the session + pub summary_prompt: String, /// REPL left prompt pub left_prompt: String, /// REPL right prompt @@ -104,6 +110,9 @@ impl Default for Config { auto_copy: false, keybindings: Default::default(), prelude: String::new(), + compress_threshold: 2000, + summarize_prompt: "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.".to_string(), + summary_prompt: "This is a summary of the chat history as a recap: ".into(), left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(), right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}" .to_string(), @@ -345,12 +354,18 @@ impl Config { self.temperature } - pub fn set_temperature(&mut self, value: Option) -> Result<()> { + pub fn set_temperature(&mut self, value: Option) { self.temperature = value; if let Some(session) = self.session.as_mut() { session.set_temperature(value); } - Ok(()) + } + + pub fn set_compress_threshold(&mut self, value: usize) { + self.compress_threshold = value; + if let Some(session) = self.session.as_mut() { + session.set_compress_threshold(value); + } } pub fn echo_messages(&self, input: &Input) -> String { @@ -430,6 +445,7 @@ impl Config { ("auto_copy", self.auto_copy.to_string()), ("keybindings", self.keybindings.stringify().into()), ("prelude", prelude), + ("compress_threshold", self.compress_threshold.to_string()), ("config_file", display_path(&Self::config_file()?)), ("roles_file", display_path(&Self::roles_file()?)), ("messages_file", display_path(&Self::messages_file()?)), @@ -445,7 +461,7 @@ impl Config { pub fn role_info(&self) -> Result { if let Some(role) = &self.role { - role.info() + role.export() } else { bail!("No role") } @@ -455,7 +471,7 @@ impl Config { if let Some(session) = &self.session { let render_options = self.get_render_options()?; let mut markdown_render = MarkdownRender::init(render_options)?; - session.render(&mut markdown_render) + session.info(&mut markdown_render) } else { bail!("No session") } @@ -465,7 +481,7 @@ impl Config { if let Some(session) = &self.session { session.export() } else if let Some(role) = &self.role { - role.info() + role.export() } else { self.sys_info() } @@ -486,6 +502,7 @@ impl Config { ".session" => self.list_sessions(), ".set" => vec![ "temperature ", + "compress_threshold", "save ", "highlight ", "dry_run ", @@ -532,7 +549,11 @@ impl Config { let value = value.parse().with_context(|| "Invalid value")?; Some(value) }; - self.set_temperature(value)?; + self.set_temperature(value); + } + "compress_threshold" => { + let value = value.parse().with_context(|| "Invalid value")?; + self.set_compress_threshold(value); } "save" => { let value = value.parse().with_context(|| "Invalid value")?; @@ -608,7 +629,7 @@ impl Config { if let Some(mut session) = self.session.take() { self.last_message = None; self.temperature = self.default_temperature; - if session.should_save() { + if session.dirty { let ans = Confirm::new("Save session?").with_default(false).prompt()?; if !ans { return Ok(()); @@ -634,7 +655,7 @@ impl Config { pub fn clear_session_messages(&mut self) -> Result<()> { if let Some(session) = self.session.as_mut() { - session.clear_messgaes(); + session.clear_messages(); } Ok(()) } @@ -660,6 +681,35 @@ impl Config { } } + pub fn should_compress_session(&mut self) -> bool { + if let Some(sesion) = self.session.as_mut() { + if sesion.need_compress(self.compress_threshold) { + sesion.compressing = true; + return true; + } + } + false + } + + pub fn compress_session(&mut self, summary: &str) { + if let Some(session) = self.session.as_mut() { + session.compress(format!("{}{}", self.summary_prompt, summary)); + } + } + + pub fn is_compressing_session(&self) -> bool { + self.session + .as_ref() + .map(|v| v.compressing) + .unwrap_or_default() + } + + pub fn end_compressing_session(&mut self) { + if let Some(session) = self.session.as_mut() { + session.compressing = false; + } + } + pub fn get_render_options(&self) -> Result { let theme = if self.highlight { let theme_mode = if self.light_theme { "light" } else { "dark" }; diff --git a/src/config/role.rs b/src/config/role.rs index 1acf027..2bff545 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -72,7 +72,7 @@ For example if the prompt is "Hello world Python", you should return "print('Hel } } - pub fn info(&self) -> Result { + pub fn export(&self) -> Result { let output = serde_yaml::to_string(&self) .with_context(|| format!("Unable to show info about role {}", &self.name))?; Ok(output.trim_end().to_string()) diff --git a/src/config/session.rs b/src/config/session.rs index 076fa27..e824c43 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -22,6 +22,9 @@ pub struct Session { messages: Vec, #[serde(default)] data_urls: HashMap, + #[serde(default)] + compressed_messages: Vec, + compress_threshold: Option, #[serde(skip)] pub name: String, #[serde(skip)] @@ -29,6 +32,8 @@ pub struct Session { #[serde(skip)] pub dirty: bool, #[serde(skip)] + pub compressing: bool, + #[serde(skip)] pub role: Option, #[serde(skip)] pub model: Model, @@ -41,10 +46,13 @@ impl Session { model_id: model.id(), temperature, messages: vec![], + compressed_messages: vec![], + compress_threshold: None, data_urls: Default::default(), name: name.to_string(), path: None, dirty: false, + compressing: false, role, model, } @@ -74,6 +82,13 @@ impl Session { self.temperature } + pub fn need_compress(&self, current_compress_threshold: usize) -> bool { + let threshold = self + .compress_threshold + .unwrap_or(current_compress_threshold); + threshold >= 1000 && self.tokens() > threshold + } + pub fn tokens(&self) -> usize { self.model.total_tokens(&self.messages) } @@ -106,7 +121,7 @@ impl Session { Ok(output) } - pub fn render(&self, render: &mut MarkdownRender) -> Result { + pub fn info(&self, render: &mut MarkdownRender) -> Result { let mut items = vec![]; if let Some(path) = &self.path { @@ -119,6 +134,10 @@ impl Session { items.push(("temperature", temperature.to_string())); } + if let Some(compress_threshold) = self.compress_threshold { + items.push(("compress_threshold", compress_threshold.to_string())); + } + if let Some(max_tokens) = self.model.max_tokens { items.push(("max_tokens", max_tokens.to_string())); } @@ -135,7 +154,7 @@ impl Session { for message in &self.messages { match message.role { MessageRole::System => { - continue; + lines.push(render.render(&message.content.render_input(resolve_url_fn))); } MessageRole::Assistant => { if let MessageContent::Text(text) = &message.content { @@ -181,14 +200,28 @@ impl Session { self.temperature = value; } + pub fn set_compress_threshold(&mut self, value: usize) { + self.compress_threshold = Some(value); + } + pub fn set_model(&mut self, model: Model) -> Result<()> { self.model_id = model.id(); self.model = model; Ok(()) } + pub fn compress(&mut self, prompt: String) { + self.compressed_messages.append(&mut self.messages); + self.messages.push(Message { + role: MessageRole::System, + content: MessageContent::Text(prompt), + }); + self.role = None; + self.dirty = true; + } + pub fn save(&mut self, session_path: &Path) -> Result<()> { - if !self.should_save() { + if !self.dirty { return Ok(()); } self.path = Some(session_path.display().to_string()); @@ -208,10 +241,6 @@ impl Session { Ok(()) } - pub fn should_save(&self) -> bool { - !self.is_empty() && self.dirty - } - pub fn guard_save(&self) -> Result<()> { if self.path.is_none() { bail!("Not found session '{}'", self.name) @@ -258,11 +287,9 @@ impl Session { Ok(()) } - pub fn clear_messgaes(&mut self) { - if self.messages.is_empty() { - return; - } + pub fn clear_messages(&mut self) { self.messages.clear(); + self.compressed_messages.clear(); self.data_urls.clear(); self.dirty = true; } @@ -275,12 +302,16 @@ impl Session { pub fn build_emssages(&self, input: &Input) -> Vec { let mut messages = self.messages.clone(); let mut need_add_msg = true; - if messages.is_empty() { + let len = messages.len(); + if len == 0 { if let Some(role) = self.role.as_ref() { messages = role.build_messages(input); need_add_msg = false; } - }; + } else if len == 1 && self.compressed_messages.len() >= 2 { + messages + .extend(self.compressed_messages[self.compressed_messages.len() - 2..].to_vec()); + } if need_add_msg { messages.push(Message { role: MessageRole::User, diff --git a/src/repl/mod.rs b/src/repl/mod.rs index e11b495..6ce52f2 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -258,6 +258,9 @@ impl Repl { if text.is_empty() && files.is_empty() { return Ok(()); } + while self.config.read().is_compressing_session() { + std::thread::sleep(std::time::Duration::from_millis(100)); + } let input = if files.is_empty() { Input::from_str(text) } else { @@ -269,6 +272,14 @@ impl Repl { let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?; self.config.write().save_message(input, &output)?; self.config.read().maybe_copy(&output); + if self.config.write().should_compress_session() { + let config = self.config.clone(); + std::thread::spawn(move || -> anyhow::Result<()> { + let _ = compress_session(&config); + config.write().end_compressing_session(); + Ok(()) + }); + } Ok(()) } @@ -418,6 +429,15 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> { } } +fn compress_session(config: &GlobalConfig) -> Result<()> { + let input = Input::from_str(&config.read().summarize_prompt); + let mut client = init_client(config)?; + ensure_model_capabilities(client.as_mut(), input.required_capabilities())?; + let summary = client.send_message(input)?; + config.write().compress_session(&summary); + Ok(()) +} + #[cfg(test)] mod tests { use super::*;