diff --git a/src/client.rs b/src/client.rs index 589ad1a..19e6c27 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,12 @@ use crate::config::SharedConfig; -use crate::repl::ReplyStreamHandler; +use crate::repl::{ReplyStreamHandler, SharedAbortSignal}; use anyhow::{anyhow, Context, Result}; use eventsource_stream::Eventsource; use futures_util::StreamExt; use reqwest::{Client, Proxy, RequestBuilder}; use serde_json::{json, Value}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use tokio::runtime::Runtime; use tokio::time::sleep; @@ -43,27 +42,27 @@ impl ChatGptClient { prompt: Option, handler: &mut ReplyStreamHandler, ) -> Result<()> { - async fn watch_ctrlc(ctrlc: Arc) { + async fn watch_abort(abort: SharedAbortSignal) { loop { - if ctrlc.load(Ordering::SeqCst) { + if abort.aborted() { break; } sleep(Duration::from_millis(100)).await; } } - let ctrlc = handler.get_ctrlc(); + let abort = handler.get_abort(); self.runtime.block_on(async { tokio::select! { ret = self.send_message_streaming_inner(input, prompt, handler) => { handler.done(); ret.with_context(|| "Failed to send message streaming") } - _ = watch_ctrlc(ctrlc.clone()) => { + _ = watch_abort(abort.clone()) => { handler.done(); Ok(()) }, _ = tokio::signal::ctrl_c() => { - ctrlc.store(true, Ordering::SeqCst); + abort.set_ctrlc(); Ok(()) } } diff --git a/src/render/mod.rs b/src/render/mod.rs index d9fcc95..b6e28f7 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -1,7 +1,7 @@ mod markdown; pub use self::markdown::MarkdownRender; -use crate::repl::ReplyStreamEvent; +use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; use anyhow::Result; use crossbeam::channel::Receiver; @@ -13,20 +13,16 @@ use crossterm::{ }; use std::{ io::{self, Stdout, Write}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, time::{Duration, Instant}, }; use unicode_width::UnicodeWidthStr; -pub fn render_stream(rx: Receiver, ctrlc: Arc) -> Result<()> { +pub fn render_stream(rx: Receiver, abort: SharedAbortSignal) -> Result<()> { enable_raw_mode()?; let mut stdout = io::stdout(); queue!(stdout, event::DisableMouseCapture)?; - let ret = render_stream_inner(rx, ctrlc, &mut stdout); + let ret = render_stream_inner(rx, abort, &mut stdout); queue!(stdout, event::DisableMouseCapture)?; disable_raw_mode()?; @@ -36,7 +32,7 @@ pub fn render_stream(rx: Receiver, ctrlc: Arc) -> pub fn render_stream_inner( rx: Receiver, - ctrlc: Arc, + abort: SharedAbortSignal, writer: &mut Stdout, ) -> Result<()> { let mut last_tick = Instant::now(); @@ -45,7 +41,7 @@ pub fn render_stream_inner( let mut markdown_render = MarkdownRender::new(); let terminal_columns = terminal::size()?.0; loop { - if ctrlc.load(Ordering::SeqCst) { + if abort.aborted() { return Ok(()); } @@ -89,7 +85,11 @@ pub fn render_stream_inner( if let Event::Key(key) = event::read()? { match key.code { KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => { - ctrlc.store(true, Ordering::SeqCst); + abort.set_ctrlc(); + return Ok(()); + } + KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => { + abort.set_ctrld(); return Ok(()); } _ => {} diff --git a/src/repl/abort.rs b/src/repl/abort.rs new file mode 100644 index 0000000..f76abb6 --- /dev/null +++ b/src/repl/abort.rs @@ -0,0 +1,51 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +pub type SharedAbortSignal = Arc; + +pub struct AbortSignal { + ctrlc: AtomicBool, + ctrld: AtomicBool, +} + +impl AbortSignal { + pub fn new() -> SharedAbortSignal { + Arc::new(Self { + ctrlc: AtomicBool::new(false), + ctrld: AtomicBool::new(false), + }) + } + + pub fn aborted(&self) -> bool { + if self.aborted_ctrlc() { + return true; + } + if self.aborted_ctrld() { + return true; + } + false + } + + pub fn aborted_ctrlc(&self) -> bool { + self.ctrlc.load(Ordering::SeqCst) + } + + pub fn aborted_ctrld(&self) -> bool { + self.ctrld.load(Ordering::SeqCst) + } + + pub fn reset(&self) { + self.ctrlc.store(false, Ordering::SeqCst); + self.ctrld.store(false, Ordering::SeqCst); + } + + pub fn set_ctrlc(&self) { + self.ctrlc.store(true, Ordering::SeqCst); + } + + pub fn set_ctrld(&self) { + self.ctrld.store(true, Ordering::SeqCst); + } +} diff --git a/src/repl/handler.rs b/src/repl/handler.rs index 809c202..f2c34e2 100644 --- a/src/repl/handler.rs +++ b/src/repl/handler.rs @@ -8,10 +8,10 @@ use crossbeam::channel::{unbounded, Sender}; use crossbeam::sync::WaitGroup; use std::cell::RefCell; use std::fs::File; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; use std::thread::spawn; +use super::abort::SharedAbortSignal; + pub enum ReplCmd { Submit(String), SetRole(String), @@ -25,7 +25,7 @@ pub struct ReplCmdHandler { client: ChatGptClient, config: SharedConfig, state: RefCell, - ctrlc: Arc, + abort: SharedAbortSignal, } pub struct ReplCmdHandlerState { @@ -34,9 +34,12 @@ pub struct ReplCmdHandlerState { } impl ReplCmdHandler { - pub fn init(client: ChatGptClient, config: SharedConfig) -> Result { + pub fn init( + client: ChatGptClient, + config: SharedConfig, + abort: SharedAbortSignal, + ) -> Result { let save_file = config.as_ref().borrow().open_message_file()?; - let ctrlc = Arc::new(AtomicBool::new(false)); let state = RefCell::new(ReplCmdHandlerState { save_file, reply: String::new(), @@ -45,7 +48,7 @@ impl ReplCmdHandler { client, config, state, - ctrlc, + abort, }) } @@ -61,15 +64,15 @@ impl ReplCmdHandler { let highlight = self.config.borrow().highlight; let mut stream_handler = if highlight { let (tx, rx) = unbounded(); - let ctrlc = self.ctrlc.clone(); + let abort = self.abort.clone(); let wg = wg.clone(); spawn(move || { - let _ = render_stream(rx, ctrlc); + let _ = render_stream(rx, abort); drop(wg); }); - ReplyStreamHandler::new(Some(tx), self.ctrlc.clone()) + ReplyStreamHandler::new(Some(tx), self.abort.clone()) } else { - ReplyStreamHandler::new(None, self.ctrlc.clone()) + ReplyStreamHandler::new(None, self.abort.clone()) }; self.client .send_message_streaming(&input, prompt, &mut stream_handler)?; @@ -109,23 +112,19 @@ impl ReplCmdHandler { pub fn get_reply(&self) -> String { self.state.borrow().reply.to_string() } - - pub fn get_ctrlc(&self) -> Arc { - self.ctrlc.clone() - } } pub struct ReplyStreamHandler { sender: Option>, buffer: String, - ctrlc: Arc, + abort: SharedAbortSignal, } impl ReplyStreamHandler { - pub fn new(sender: Option>, ctrlc: Arc) -> Self { + pub fn new(sender: Option>, abort: SharedAbortSignal) -> Self { Self { sender, - ctrlc, + abort, buffer: String::new(), } } @@ -157,8 +156,8 @@ impl ReplyStreamHandler { &self.buffer } - pub fn get_ctrlc(&self) -> Arc { - self.ctrlc.clone() + pub fn get_abort(&self) -> SharedAbortSignal { + self.abort.clone() } } diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 826ec6a..30e1c5e 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -1,3 +1,4 @@ +mod abort; mod handler; mod init; @@ -8,9 +9,9 @@ use crate::utils::{copy, dump}; use anyhow::{Context, Result}; use reedline::{DefaultPrompt, Reedline, Signal}; -use std::sync::atomic::Ordering; use std::sync::Arc; +pub use self::abort::*; pub use self::handler::*; pub const REPL_COMMANDS: [(&str, &str, bool); 12] = [ @@ -35,23 +36,27 @@ pub struct Repl { impl Repl { pub fn run(&mut self, client: ChatGptClient, config: SharedConfig) -> Result<()> { - let handler = ReplCmdHandler::init(client, config)?; + let abort = AbortSignal::new(); + let handler = ReplCmdHandler::init(client, config, abort.clone())?; dump( format!("Welcome to aichat {}", env!("CARGO_PKG_VERSION")), 1, ); dump("Type \".help\" for more information.", 1); - let mut current_ctrlc = false; + let mut already_ctrlc = false; let handler = Arc::new(handler); loop { - let handler_ctrlc = handler.get_ctrlc(); - if handler_ctrlc.load(Ordering::SeqCst) { - handler_ctrlc.store(false, Ordering::SeqCst); - current_ctrlc = true + if abort.aborted_ctrld() { + break; } - match self.editor.read_line(&self.prompt) { + if abort.aborted_ctrlc() && !already_ctrlc { + already_ctrlc = true; + } + let sig = self.editor.read_line(&self.prompt); + match sig { Ok(Signal::Success(line)) => { - current_ctrlc = false; + already_ctrlc = false; + abort.reset(); match self.handle_line(handler.clone(), line) { Ok(quit) => { if quit { @@ -65,14 +70,16 @@ impl Repl { } } Ok(Signal::CtrlC) => { - if !current_ctrlc { - current_ctrlc = true; + abort.set_ctrlc(); + if !already_ctrlc { + already_ctrlc = true; dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2); } else { break; } } Ok(Signal::CtrlD) => { + abort.set_ctrld(); break; } _ => {}