From bc44026ff8828f532e9eba138e2ba35f7c247271 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 28 Oct 2023 21:39:17 +0800 Subject: [PATCH] feat: enhance session/conversation (#162) * feat: enhance session/conversation * updates * updates * cut version v0.9.0-rc2 * add .session name completion --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 52 ++++----- src/cli.rs | 6 + src/config/conversation.rs | 91 --------------- src/config/mod.rs | 229 +++++++++++++++++++++++++++---------- src/config/role.rs | 6 +- src/config/session.rs | 159 +++++++++++++++++++++++++ src/main.rs | 28 +++-- src/repl/handler.rs | 29 ++--- src/repl/mod.rs | 13 ++- src/repl/prompt.rs | 16 +-- 12 files changed, 409 insertions(+), 224 deletions(-) delete mode 100644 src/config/conversation.rs create mode 100644 src/config/session.rs diff --git a/Cargo.lock b/Cargo.lock index fc0dc1d..7d93df1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ [[package]] name = "aichat" -version = "0.9.0-rc1" +version = "0.9.0-rc2" dependencies = [ "anyhow", "arboard", diff --git a/Cargo.toml b/Cargo.toml index 6a2753d..3e142dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aichat" -version = "0.9.0-rc1" +version = "0.9.0-rc2" edition = "2021" authors = ["sigoden "] description = "A powerful chatgpt cli." diff --git a/README.md b/README.md index 41b0ba4..ee212a0 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases), - Predefine AI [roles](#roles) - Use GPT prompt easily - Powerful [Chat REPL](#chat-repl) -- Context-aware conversation +- Context-aware conversation/session - Syntax highlighting markdown and 200 other languages - Stream output with hand-typing effect - Support multiple models @@ -58,9 +58,8 @@ On completion, it will automatically create the configuration file. Of course, y ```yaml model: openai:gpt-3.5-turbo # Choose a model temperature: 1.0 # See https://platform.openai.com/docs/api-reference/chat/create#chat/create-temperature -save: true # If set true, aichat will save chat messages to message.md +save: true # If set true, aichat will save non-session chat messages to messages.md highlight: true # Set false to turn highlight -conversation_first: false # If set true, start a conversation immediately upon repl light_theme: false # If set true, use light theme auto_copy: false # Automatically copy the last output to the clipboard keybindings: emacs # REPL keybindings, possible values: emacs (default), vi @@ -167,10 +166,10 @@ AIChat also provides `.edit` command for multi-lines editing. .prompt Add a GPT prompt .role Select a role .clear role Clear the currently selected role -.conversation Start a conversation. -.clear conversation End current conversation. +.session Start a session +.clear session End current session .copy Copy the last output to the clipboard -.read Read the contents of a file into the prompt +.read Read the contents of a file and submit .edit Multi-line editing (CTRL+S to finish) .history Print the history .clear history Clear the history @@ -187,11 +186,11 @@ Press Ctrl+C to abort readline, Ctrl+D to exit the REPL config_file /home/alice/.config/aichat/config.yaml roles_file /home/alice/.config/aichat/roles.yaml messages_file /home/alice/.config/aichat/messages.md +sessions_dir /home/alice/.config/aichat/sessions model openai:gpt-3.5-turbo temperature 0.7 save true highlight true -conversation_first false light_theme false dry_run false vi_keybindings true @@ -264,41 +263,30 @@ emoji〉.clear role Hello there! How can I assist you today? ``` -### `.conversation` - start a context-aware conversation +## Session - context-aware conversation By default, aichat behaves in a one-off request/response manner. +You should run aichat with "-s/--session" or use the ".session" command to start a session. -You can run `.conversation` to enter context-aware mode, or set `config.conversation_first` true to start a conversation immediately upon repl. ``` -〉.conversation - -)list 1 to 5, one per line 4089 -1 -2 -3 -4 -5 - -)reverse the list 4065 -5 -4 -3 -2 -1 +〉.session +temp)1 to 5, odd only 4089 +1, 3, 5 -``` - -When entering conversation mode, prompt `〉` will change to `)`. A number will appear on the right, -indicating how many tokens are left to use. -Once the number becomes zero, you need to start a new conversation. +temp)to 7 4070 +1, 3, 5, 7 -Exit conversation mode: +temp).clear session +〉 ``` -).clear conversation 4043 -〉 +```sh +aichat --list-sessions # List sessions +aichat -s # Start with a new session +aichat -s rust # If session rust exists, use it. If it does not exist, create a new session. +aichat -s rust --info # Show session details ``` ## License diff --git a/src/cli.rs b/src/cli.rs index 42285e6..1df8ba1 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -31,6 +31,12 @@ pub struct Cli { /// Run in dry run mode #[clap(long)] pub dry_run: bool, + /// List sessions + #[clap(long)] + pub list_sessions: bool, + /// Initiate or continue named session + #[clap(short = 's', long)] + pub session: Option>, /// Input text text: Vec, } diff --git a/src/config/conversation.rs b/src/config/conversation.rs deleted file mode 100644 index b9a0793..0000000 --- a/src/config/conversation.rs +++ /dev/null @@ -1,91 +0,0 @@ -use super::message::{num_tokens_from_messages, Message, MessageRole}; -use super::role::Role; - -use anyhow::{bail, Result}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Conversation { - pub tokens: usize, - pub role: Option, - pub messages: Vec, -} - -impl Conversation { - pub fn new(role: Option) -> Self { - let mut value = Self { - tokens: 0, - role, - messages: vec![], - }; - value.update_tokens(); - value - } - - pub fn update_role(&mut self, role: &Role) -> Result<()> { - if self.messages.is_empty() { - self.role = Some(role.clone()); - self.update_tokens(); - } else { - bail!("Error: Cannot perform this action in the middle of conversation") - } - Ok(()) - } - - pub fn can_clear_role(&self) -> Result<()> { - if self.messages.is_empty() { - return Ok(()); - } - bail!("Error: Cannot perform this action in the middle of conversation") - } - - pub fn update_tokens(&mut self) { - self.tokens = num_tokens_from_messages(&self.build_emssages("")); - } - - #[allow(clippy::unnecessary_wraps)] - pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { - let mut need_add_msg = true; - if self.messages.is_empty() { - if let Some(role) = self.role.as_ref() { - self.messages.extend(role.build_messages(input)); - need_add_msg = false; - } - } - if need_add_msg { - self.messages.push(Message { - role: MessageRole::User, - content: input.to_string(), - }); - } - self.messages.push(Message { - role: MessageRole::Assistant, - content: output.to_string(), - }); - self.tokens = num_tokens_from_messages(&self.messages); - Ok(()) - } - - pub fn echo_messages(&self, content: &str) -> String { - let messages = self.build_emssages(content); - serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) - } - - pub fn build_emssages(&self, content: &str) -> Vec { - let mut messages = self.messages.clone(); - let mut need_add_msg = true; - if messages.is_empty() { - if let Some(role) = self.role.as_ref() { - messages = role.build_messages(content); - need_add_msg = false; - } - }; - if need_add_msg { - messages.push(Message { - role: MessageRole::User, - content: content.into(), - }); - } - messages - } -} diff --git a/src/config/mod.rs b/src/config/mod.rs index 8e5157f..73c9444 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,10 +1,10 @@ -mod conversation; mod message; mod role; +mod session; -use self::conversation::Conversation; use self::message::Message; use self::role::Role; +use self::session::{Session, TEMP_SESSION_NAME}; use crate::client::openai::{OpenAIClient, OpenAIConfig}; use crate::client::{all_clients, create_client_config, list_models, ClientConfig, ModelInfo}; @@ -12,12 +12,12 @@ use crate::config::message::num_tokens_from_messages; use crate::utils::{get_env_name, now}; use anyhow::{anyhow, bail, Context, Result}; -use inquire::{Confirm, Select}; +use inquire::{Confirm, Select, Text}; use parking_lot::RwLock; use serde::Deserialize; use std::{ env, - fs::{create_dir_all, read_to_string, File, OpenOptions}, + fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions}, io::Write, path::{Path, PathBuf}, process::exit, @@ -27,7 +27,9 @@ use std::{ const CONFIG_FILE_NAME: &str = "config.yaml"; const ROLES_FILE_NAME: &str = "roles.yaml"; const HISTORY_FILE_NAME: &str = "history.txt"; -const MESSAGE_FILE_NAME: &str = "messages.md"; +const MESSAGES_FILE_NAME: &str = "messages.md"; +const SESSIONS_DIR_NAME: &str = "sessions"; + const SET_COMPLETIONS: [&str; 7] = [ ".set temperature", ".set save true", @@ -46,14 +48,12 @@ pub struct Config { pub model: Option, /// What sampling temperature to use, between 0 and 2 pub temperature: Option, - /// Whether to persistently save chat messages + /// Whether to persistently save non-session chat messages pub save: bool, /// Whether to disable highlight pub highlight: bool, /// Used only for debugging pub dry_run: bool, - /// If set ture, start a conversation immediately upon repl - pub conversation_first: bool, /// If set true, use light theme pub light_theme: bool, /// Automatically copy the last output to the clipboard @@ -68,11 +68,13 @@ pub struct Config { /// Current selected role #[serde(skip)] pub role: Option, - /// Current conversation + /// Current session #[serde(skip)] - pub conversation: Option, + pub session: Option, #[serde(skip)] pub model_info: ModelInfo, + #[serde(skip)] + pub last_message: Option<(String, String)>, } impl Default for Config { @@ -83,15 +85,15 @@ impl Default for Config { save: false, highlight: true, dry_run: false, - conversation_first: false, light_theme: false, auto_copy: false, keybindings: Default::default(), - roles: vec![], clients: vec![ClientConfig::OpenAI(OpenAIConfig::default())], + roles: vec![], role: None, - conversation: None, + session: None, model_info: Default::default(), + last_message: None, } } } @@ -123,19 +125,14 @@ impl Config { if let Some(name) = config.model.clone() { config.set_model(&name)?; } + config.merge_env_vars(); config.load_roles()?; + config.ensure_sessions_dir()?; Ok(config) } - pub fn on_repl(&mut self) -> Result<()> { - if self.conversation_first { - self.start_conversation()?; - } - Ok(()) - } - pub fn get_role(&self, name: &str) -> Option { self.roles.iter().find(|v| v.match_name(name)).map(|v| { let mut role = v.clone(); @@ -156,13 +153,20 @@ impl Config { Ok(path) } - pub fn local_file(name: &str) -> Result { + pub fn local_path(name: &str) -> Result { let mut path = Self::config_dir()?; path.push(name); Ok(path) } - pub fn save_message(&self, input: &str, output: &str) -> Result<()> { + pub fn save_message(&mut self, input: &str, output: &str) -> Result<()> { + self.last_message = Some((input.to_string(), output.to_string())); + + if let Some(session) = self.session.as_mut() { + session.add_message(input, output)?; + return Ok(()); + } + if !self.save { return Ok(()); } @@ -193,30 +197,40 @@ impl Config { } pub fn config_file() -> Result { - Self::local_file(CONFIG_FILE_NAME) + Self::local_path(CONFIG_FILE_NAME) } pub fn roles_file() -> Result { let env_name = get_env_name("roles_file"); env::var(env_name).map_or_else( - |_| Self::local_file(ROLES_FILE_NAME), + |_| Self::local_path(ROLES_FILE_NAME), |value| Ok(PathBuf::from(value)), ) } pub fn history_file() -> Result { - Self::local_file(HISTORY_FILE_NAME) + Self::local_path(HISTORY_FILE_NAME) } pub fn messages_file() -> Result { - Self::local_file(MESSAGE_FILE_NAME) + Self::local_path(MESSAGES_FILE_NAME) + } + + pub fn sessions_dir() -> Result { + Self::local_path(SESSIONS_DIR_NAME) + } + + pub fn session_file(name: &str) -> Result { + let mut path = Self::sessions_dir()?; + path.push(&format!("{name}.yaml")); + Ok(path) } pub fn change_role(&mut self, name: &str) -> Result { match self.get_role(name) { Some(role) => { - if let Some(conversation) = self.conversation.as_mut() { - conversation.update_role(&role)?; + if let Some(session) = self.session.as_mut() { + session.update_role(Some(role.clone()))?; } let output = serde_yaml::to_string(&role) .unwrap_or_else(|_| "Unable to echo role details".into()); @@ -228,8 +242,8 @@ impl Config { } pub fn clear_role(&mut self) -> Result<()> { - if let Some(conversation) = self.conversation.as_ref() { - conversation.can_clear_role()?; + if let Some(session) = self.session.as_mut() { + session.update_role(None)?; } self.role = None; Ok(()) @@ -237,8 +251,8 @@ impl Config { pub fn add_prompt(&mut self, prompt: &str) -> Result<()> { let role = Role::new(prompt, self.temperature); - if let Some(conversation) = self.conversation.as_mut() { - conversation.update_role(&role)?; + if let Some(session) = self.session.as_mut() { + session.update_role(Some(role.clone()))?; } self.role = Some(role); Ok(()) @@ -253,8 +267,8 @@ impl Config { pub fn echo_messages(&self, content: &str) -> String { #[allow(clippy::option_if_let_else)] - if let Some(conversation) = self.conversation.as_ref() { - conversation.echo_messages(content) + if let Some(session) = self.session.as_ref() { + session.echo_messages(content) } else if let Some(role) = self.role.as_ref() { role.echo_messages(content) } else { @@ -264,8 +278,8 @@ impl Config { pub fn build_messages(&self, content: &str) -> Result> { #[allow(clippy::option_if_let_else)] - let messages = if let Some(conversation) = self.conversation.as_ref() { - conversation.build_emssages(content) + let messages = if let Some(session) = self.session.as_ref() { + session.build_emssages(content) } else if let Some(role) = self.role.as_ref() { role.build_messages(content) } else { @@ -282,28 +296,36 @@ impl Config { pub fn set_model(&mut self, value: &str) -> Result<()> { let models = list_models(self); + let mut model_info = None; if value.contains(':') { if let Some(model) = models.iter().find(|v| v.stringify() == value) { - self.model_info = model.clone(); - return Ok(()); + model_info = Some(model.clone()); } } else if let Some(model) = models.iter().find(|v| v.client == value) { - self.model_info = model.clone(); - return Ok(()); + model_info = Some(model.clone()); + } + match model_info { + None => bail!("Invalid model"), + Some(model_info) => { + if let Some(session) = self.session.as_mut() { + session.model = model_info.stringify(); + } + self.model_info = model_info; + Ok(()) + } } - bail!("Invalid model") } pub const fn get_reamind_tokens(&self) -> usize { let mut tokens = self.model_info.max_tokens; - if let Some(conversation) = self.conversation.as_ref() { - tokens = tokens.saturating_sub(conversation.tokens); + if let Some(session) = self.session.as_ref() { + tokens = tokens.saturating_sub(session.tokens); } tokens } pub fn info(&self) -> Result { - let file_info = |path: &Path| { + let path_info = |path: &Path| { let state = if path.exists() { "" } else { " ⚠️" }; format!("{}{state}", path.display()) }; @@ -311,14 +333,14 @@ impl Config { .temperature .map_or_else(|| String::from("-"), |v| v.to_string()); let items = vec![ - ("config_file", file_info(&Self::config_file()?)), - ("roles_file", file_info(&Self::roles_file()?)), - ("messages_file", file_info(&Self::messages_file()?)), + ("config_file", path_info(&Self::config_file()?)), + ("roles_file", path_info(&Self::roles_file()?)), + ("messages_file", path_info(&Self::messages_file()?)), + ("sessions_dir", path_info(&Self::sessions_dir()?)), ("model", self.model_info.stringify()), ("temperature", temperature), ("save", self.save.to_string()), ("highlight", self.highlight.to_string()), - ("conversation_first", self.conversation_first.to_string()), ("light_theme", self.light_theme.to_string()), ("dry_run", self.dry_run.to_string()), ("keybindings", self.keybindings.stringify().into()), @@ -343,6 +365,13 @@ impl Config { .iter() .map(|v| format!(".model {}", v.stringify())), ); + completion.extend( + list_models(self) + .iter() + .map(|v| format!(".model {}", v.stringify())), + ); + let sessions = self.list_sessions().unwrap_or_default(); + completion.extend(sessions.iter().map(|v| format!(".session {}", v))); completion } @@ -380,28 +409,94 @@ impl Config { Ok(()) } - pub fn start_conversation(&mut self) -> Result<()> { - if self.conversation.is_some() && self.get_reamind_tokens() > 0 { - let ans = Confirm::new("Already in a conversation, start a new one?") - .with_default(true) - .prompt()?; - if !ans { - return Ok(()); + pub fn start_session(&mut self, session: &Option) -> Result<()> { + if self.session.is_some() { + bail!("Already in a session, please use '.clear session' to exit the session first?"); + } + match session { + None => { + let session_file = Self::session_file(TEMP_SESSION_NAME)?; + if session_file.exists() { + remove_file(session_file) + .with_context(|| "Failed to clean previous session")?; + } + self.session = Some(Session::new( + TEMP_SESSION_NAME, + &self.model_info.stringify(), + self.role.clone(), + )); + } + Some(name) => { + let session_path = Self::session_file(name)?; + if !session_path.exists() { + self.session = Some(Session::new( + name, + &self.model_info.stringify(), + self.role.clone(), + )); + } else { + let mut session = Session::load(name, &session_path)?; + if let Some(role) = &session.role { + self.change_role(&role.name)?; + } + self.set_model(&session.model)?; + session.update_tokens(); + self.session = Some(session); + } + } + } + if let Some(session) = self.session.as_mut() { + if session.is_empty() { + if let Some((input, output)) = &self.last_message { + let ans = Confirm::new( + "Start a session that incorporates the last question and answer?", + ) + .with_default(false) + .prompt()?; + if ans { + session.add_message(input, output)?; + } + } } } - self.conversation = Some(Conversation::new(self.role.clone())); Ok(()) } - pub fn end_conversation(&mut self) { - self.conversation = None; + pub fn end_session(&mut self) -> Result<()> { + if let Some(mut session) = self.session.take() { + self.last_message = None; + if session.should_save() { + let ans = Confirm::new("Save session?").with_default(true).prompt()?; + if !ans { + return Ok(()); + } + let mut name = session.name.clone(); + if session.is_temp() { + name = Text::new("Session name:").with_default(&name).prompt()?; + } + let session_path = Self::session_file(&name)?; + session.save(&session_path)?; + } + } + Ok(()) } - pub fn save_conversation(&mut self, input: &str, output: &str) -> Result<()> { - if let Some(conversation) = self.conversation.as_mut() { - conversation.add_message(input, output)?; + pub fn list_sessions(&self) -> Result> { + let sessions_dir = Self::sessions_dir()?; + match read_dir(&sessions_dir) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd { + let entry = entry?; + let name = entry.file_name(); + if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { + names.push(name.to_string()); + } + } + Ok(names) + } + Err(_) => Ok(vec![]), } - Ok(()) } pub const fn get_render_options(&self) -> (bool, bool) { @@ -463,6 +558,16 @@ impl Config { } } + fn ensure_sessions_dir(&self) -> Result<()> { + let sessions_dir = Self::sessions_dir()?; + if !sessions_dir.exists() { + create_dir_all(&sessions_dir).with_context(|| { + format!("Failed to create session_dir '{}'", sessions_dir.display()) + })?; + } + Ok(()) + } + fn compat_old_config(&mut self, config_path: &PathBuf) -> Result<()> { let content = read_to_string(config_path)?; let value: serde_json::Value = serde_yaml::from_str(&content)?; diff --git a/src/config/role.rs b/src/config/role.rs index 32ade08..23a69df 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -2,7 +2,7 @@ use super::message::{Message, MessageRole}; use serde::{Deserialize, Serialize}; -const TEMP_NAME: &str = "P"; +const TEMP_ROLE_NAME: &str = "temp"; const INPUT_PLACEHOLDER: &str = "__INPUT__"; #[derive(Debug, Clone, Deserialize, Serialize)] @@ -18,14 +18,14 @@ pub struct Role { impl Role { pub fn new(prompt: &str, temperature: Option) -> Self { Self { - name: TEMP_NAME.into(), + name: TEMP_ROLE_NAME.into(), prompt: prompt.into(), temperature, } } pub fn is_temp(&self) -> bool { - self.name == TEMP_NAME + self.name == TEMP_ROLE_NAME } pub fn embeded(&self) -> bool { diff --git a/src/config/session.rs b/src/config/session.rs new file mode 100644 index 0000000..95b01ef --- /dev/null +++ b/src/config/session.rs @@ -0,0 +1,159 @@ +use super::message::{num_tokens_from_messages, Message, MessageRole}; +use super::role::Role; + +use anyhow::{bail, Context, Result}; +use serde::{Deserialize, Serialize}; +use std::fs::{self, read_to_string}; +use std::path::Path; + +pub const TEMP_SESSION_NAME: &str = "temp"; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Session { + pub path: Option, + pub model: String, + pub tokens: usize, + pub messages: Vec, + #[serde(skip)] + pub dirty: bool, + #[serde(skip)] + pub role: Option, + #[serde(skip)] + pub name: String, +} + +impl Session { + pub fn new(name: &str, model: &str, role: Option) -> Self { + let mut value = Self { + path: None, + model: model.to_string(), + tokens: 0, + messages: vec![], + dirty: false, + role, + name: name.to_string(), + }; + value.update_tokens(); + value + } + + pub fn load(name: &str, path: &Path) -> Result { + let content = read_to_string(path) + .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?; + let mut session: Self = + serde_yaml::from_str(&content).with_context(|| format!("Invalid sesion {}", name))?; + + session.name = name.to_string(); + session.path = Some(path.display().to_string()); + + Ok(session) + } + + pub fn info(&self) -> Result { + self.guard_save()?; + let output = serde_yaml::to_string(&self) + .with_context(|| format!("Unable to show info about session {}", &self.name))?; + Ok(output) + } + + pub fn update_role(&mut self, role: Option) -> Result<()> { + self.guard_empty()?; + self.role = role; + self.update_tokens(); + Ok(()) + } + + pub fn save(&mut self, session_path: &Path) -> Result<()> { + if !self.should_save() { + return Ok(()); + } + self.dirty = false; + let content = serde_yaml::to_string(&self) + .with_context(|| format!("Failed to serde session {}", self.name))?; + fs::write(session_path, content).with_context(|| { + format!( + "Failed to write session {} to {}", + self.name, + session_path.display() + ) + })?; + Ok(()) + } + + pub fn should_save(&self) -> bool { + !self.is_empty() && self.dirty + } + + pub fn guard_save(&self) -> Result<()> { + if self.path.is_none() { + bail!("Not found session '{}'", self.name) + } + Ok(()) + } + + pub fn guard_empty(&self) -> Result<()> { + if !self.is_empty() { + bail!("Cannot perform this action in session") + } + Ok(()) + } + + pub fn is_temp(&self) -> bool { + self.name == TEMP_SESSION_NAME + } + + pub fn is_empty(&self) -> bool { + self.messages.is_empty() + } + + pub fn update_tokens(&mut self) { + self.tokens = num_tokens_from_messages(&self.build_emssages("")); + } + + #[allow(clippy::unnecessary_wraps)] + pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { + let mut need_add_msg = true; + if self.messages.is_empty() { + if let Some(role) = self.role.as_ref() { + self.messages.extend(role.build_messages(input)); + need_add_msg = false; + } + } + if need_add_msg { + self.messages.push(Message { + role: MessageRole::User, + content: input.to_string(), + }); + } + self.messages.push(Message { + role: MessageRole::Assistant, + content: output.to_string(), + }); + self.tokens = num_tokens_from_messages(&self.messages); + self.dirty = true; + Ok(()) + } + + pub fn echo_messages(&self, content: &str) -> String { + let messages = self.build_emssages(content); + serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) + } + + pub fn build_emssages(&self, content: &str) -> Vec { + let mut messages = self.messages.clone(); + let mut need_add_msg = true; + if messages.is_empty() { + if let Some(role) = self.role.as_ref() { + messages = role.build_messages(content); + need_add_msg = false; + } + }; + if need_add_msg { + messages.push(Message { + role: MessageRole::User, + content: content.into(), + }); + } + messages + } +} diff --git a/src/main.rs b/src/main.rs index 65bc5f2..d73d4c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,9 +42,20 @@ fn main() -> Result<()> { } exit(0); } + if cli.list_sessions { + let sessions = config.read().list_sessions()?.join("\n"); + println!("{sessions}"); + exit(0); + } if cli.dry_run { config.write().dry_run = true; } + if let Some(session) = &cli.session { + config.write().start_session(session)?; + } + if let Some(model) = &cli.model { + config.write().set_model(model)?; + } let role = match &cli.role { Some(name) => Some( config @@ -54,9 +65,6 @@ fn main() -> Result<()> { ), None => None, }; - if let Some(model) = &cli.model { - config.write().set_model(model)?; - } config.write().role = role; if cli.no_highlight { config.write().highlight = false; @@ -65,7 +73,11 @@ fn main() -> Result<()> { config.write().add_prompt(prompt)?; } if cli.info { - let info = config.read().info()?; + let info = if let Some(session) = &config.read().session { + session.info()? + } else { + config.read().info()? + }; println!("{info}"); exit(0); } @@ -92,6 +104,9 @@ fn start_directive( input: &str, no_stream: bool, ) -> Result<()> { + if let Some(sesion) = &config.read().session { + sesion.guard_save()?; + } if !stdout().is_terminal() { config.write().highlight = false; } @@ -118,12 +133,11 @@ fn start_directive( wg.wait(); output }; - config.read().save_message(input, &output) + config.write().save_message(input, &output) } fn start_interactive(config: SharedConfig) -> Result<()> { cl100k_base_singleton(); - config.write().on_repl()?; - let mut repl = Repl::init(config.clone())?; + let mut repl: Repl = Repl::init(config.clone())?; repl.run(config) } diff --git a/src/repl/handler.rs b/src/repl/handler.rs index 1952ebe..fd157dc 100644 --- a/src/repl/handler.rs +++ b/src/repl/handler.rs @@ -21,8 +21,8 @@ pub enum ReplCmd { Prompt(String), ClearRole, ViewInfo, - StartConversation, - EndConversatoin, + StartSession(Option), + EndSession, Copy, ReadFile(String), } @@ -30,7 +30,6 @@ pub enum ReplCmd { #[allow(clippy::module_name_repetitions)] pub struct ReplCmdHandler { config: SharedConfig, - reply: RefCell, abort: SharedAbortSignal, clipboard: std::result::Result, arboard::Error>, } @@ -38,11 +37,9 @@ pub struct ReplCmdHandler { impl ReplCmdHandler { #[allow(clippy::unnecessary_wraps)] pub fn init(config: SharedConfig, abort: SharedAbortSignal) -> Result { - let reply = RefCell::new(String::new()); let clipboard = Clipboard::new().map(RefCell::new); Ok(Self { config, - reply, abort, clipboard, }) @@ -52,7 +49,6 @@ impl ReplCmdHandler { match cmd { ReplCmd::Submit(input) => { if input.is_empty() { - self.reply.borrow_mut().clear(); return Ok(()); } self.config.read().maybe_print_send_tokens(&input); @@ -68,12 +64,10 @@ impl ReplCmdHandler { ); wg.wait(); let buffer = ret?; - self.config.read().save_message(&input, &buffer)?; + self.config.write().save_message(&input, &buffer)?; if self.config.read().auto_copy { let _ = self.copy(&buffer); } - self.config.write().save_conversation(&input, &buffer)?; - *self.reply.borrow_mut() = buffer; } ReplCmd::SetModel(name) => { self.config.write().set_model(&name)?; @@ -99,16 +93,23 @@ impl ReplCmdHandler { self.config.write().update(&input)?; print_now!("\n"); } - ReplCmd::StartConversation => { - self.config.write().start_conversation()?; + ReplCmd::StartSession(name) => { + self.config.write().start_session(&name)?; print_now!("\n"); } - ReplCmd::EndConversatoin => { - self.config.write().end_conversation(); + ReplCmd::EndSession => { + self.config.write().end_session()?; print_now!("\n"); } ReplCmd::Copy => { - self.copy(&self.reply.borrow()) + let reply = self + .config + .read() + .last_message + .as_ref() + .map(|v| v.1.clone()) + .unwrap_or_default(); + self.copy(&reply) .with_context(|| "Failed to copy the last output")?; print_now!("\n"); } diff --git a/src/repl/mod.rs b/src/repl/mod.rs index f498ab4..797519b 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -26,10 +26,10 @@ pub const REPL_COMMANDS: [(&str, &str); 15] = [ (".prompt", "Add a GPT prompt"), (".role", "Select a role"), (".clear role", "Clear the currently selected role"), - (".conversation", "Start a conversation."), - (".clear conversation", "End current conversation."), + (".session", "Start a session"), + (".clear session", "End current session"), (".copy", "Copy the last output to the clipboard"), - (".read", "Read the contents of a file into the prompt"), + (".read", "Read the contents of a file and submit"), (".edit", "Multi-line editing (CTRL+S to finish)"), (".history", "Print the history"), (".clear history", "Clear the history"), @@ -89,6 +89,7 @@ impl Repl { _ => {} } } + handler.handle(ReplCmd::EndSession)?; Ok(()) } @@ -111,7 +112,7 @@ impl Repl { print_now!("\n"); } Some("role") => handler.handle(ReplCmd::ClearRole)?, - Some("conversation") => handler.handle(ReplCmd::EndConversatoin)?, + Some("session") => handler.handle(ReplCmd::EndSession)?, _ => dump_unknown_command(), }, ".history" => { @@ -141,8 +142,8 @@ impl Repl { handler.handle(ReplCmd::Prompt(text))?; } } - ".conversation" => { - handler.handle(ReplCmd::StartConversation)?; + ".session" => { + handler.handle(ReplCmd::StartSession(args.map(|v| v.to_string())))?; } ".copy" => { handler.handle(ReplCmd::Copy)?; diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index 1d92407..be3763b 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -69,15 +69,17 @@ impl ReplPrompt { impl Prompt for ReplPrompt { fn render_prompt_left(&self) -> Cow { - self.config - .read() - .role - .as_ref() - .map_or(Cow::Borrowed(""), |role| Cow::Owned(role.name.clone())) + if let Some(session) = &self.config.read().session { + Cow::Owned(session.name.clone()) + } else if let Some(role) = &self.config.read().role { + Cow::Owned(role.name.clone()) + } else { + Cow::Borrowed("") + } } fn render_prompt_right(&self) -> Cow { - if self.config.read().conversation.is_none() { + if self.config.read().session.is_none() { Cow::Borrowed("") } else { self.config.read().get_reamind_tokens().to_string().into() @@ -85,7 +87,7 @@ impl Prompt for ReplPrompt { } fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow { - if self.config.read().conversation.is_some() { + if self.config.read().session.is_some() { Cow::Borrowed(")") } else { Cow::Borrowed("〉")