refactor: optimize ctrl+c/ctrl+d abort handling (#27)

This commit is contained in:
sigoden 2023-03-07 11:51:52 +08:00 committed by GitHub
parent 1640456049
commit 11dc4d104b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 104 additions and 48 deletions

View File

@ -1,13 +1,12 @@
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
use anyhow::{anyhow, Context, Result};
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use reqwest::{Client, Proxy, RequestBuilder};
use serde_json::{json, Value};
use std::sync::atomic::{AtomicBool, Ordering};
use std::{sync::Arc, time::Duration};
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::time::sleep;
@ -43,27 +42,27 @@ impl ChatGptClient {
prompt: Option<String>,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_ctrlc(ctrlc: Arc<AtomicBool>) {
async fn watch_abort(abort: SharedAbortSignal) {
loop {
if ctrlc.load(Ordering::SeqCst) {
if abort.aborted() {
break;
}
sleep(Duration::from_millis(100)).await;
}
}
let ctrlc = handler.get_ctrlc();
let abort = handler.get_abort();
self.runtime.block_on(async {
tokio::select! {
ret = self.send_message_streaming_inner(input, prompt, handler) => {
handler.done();
ret.with_context(|| "Failed to send message streaming")
}
_ = watch_ctrlc(ctrlc.clone()) => {
_ = watch_abort(abort.clone()) => {
handler.done();
Ok(())
},
_ = tokio::signal::ctrl_c() => {
ctrlc.store(true, Ordering::SeqCst);
abort.set_ctrlc();
Ok(())
}
}

View File

@ -1,7 +1,7 @@
mod markdown;
pub use self::markdown::MarkdownRender;
use crate::repl::ReplyStreamEvent;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
@ -13,20 +13,16 @@ use crossterm::{
};
use std::{
io::{self, Stdout, Write},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
use unicode_width::UnicodeWidthStr;
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, ctrlc: Arc<AtomicBool>) -> Result<()> {
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
queue!(stdout, event::DisableMouseCapture)?;
let ret = render_stream_inner(rx, ctrlc, &mut stdout);
let ret = render_stream_inner(rx, abort, &mut stdout);
queue!(stdout, event::DisableMouseCapture)?;
disable_raw_mode()?;
@ -36,7 +32,7 @@ pub fn render_stream(rx: Receiver<ReplyStreamEvent>, ctrlc: Arc<AtomicBool>) ->
pub fn render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
ctrlc: Arc<AtomicBool>,
abort: SharedAbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
@ -45,7 +41,7 @@ pub fn render_stream_inner(
let mut markdown_render = MarkdownRender::new();
let terminal_columns = terminal::size()?.0;
loop {
if ctrlc.load(Ordering::SeqCst) {
if abort.aborted() {
return Ok(());
}
@ -89,7 +85,11 @@ pub fn render_stream_inner(
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
ctrlc.store(true, Ordering::SeqCst);
abort.set_ctrlc();
return Ok(());
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
return Ok(());
}
_ => {}

51
src/repl/abort.rs Normal file
View File

@ -0,0 +1,51 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
pub type SharedAbortSignal = Arc<AbortSignal>;
pub struct AbortSignal {
ctrlc: AtomicBool,
ctrld: AtomicBool,
}
impl AbortSignal {
pub fn new() -> SharedAbortSignal {
Arc::new(Self {
ctrlc: AtomicBool::new(false),
ctrld: AtomicBool::new(false),
})
}
pub fn aborted(&self) -> bool {
if self.aborted_ctrlc() {
return true;
}
if self.aborted_ctrld() {
return true;
}
false
}
pub fn aborted_ctrlc(&self) -> bool {
self.ctrlc.load(Ordering::SeqCst)
}
pub fn aborted_ctrld(&self) -> bool {
self.ctrld.load(Ordering::SeqCst)
}
pub fn reset(&self) {
self.ctrlc.store(false, Ordering::SeqCst);
self.ctrld.store(false, Ordering::SeqCst);
}
pub fn set_ctrlc(&self) {
self.ctrlc.store(true, Ordering::SeqCst);
}
pub fn set_ctrld(&self) {
self.ctrld.store(true, Ordering::SeqCst);
}
}

View File

@ -8,10 +8,10 @@ use crossbeam::channel::{unbounded, Sender};
use crossbeam::sync::WaitGroup;
use std::cell::RefCell;
use std::fs::File;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::thread::spawn;
use super::abort::SharedAbortSignal;
pub enum ReplCmd {
Submit(String),
SetRole(String),
@ -25,7 +25,7 @@ pub struct ReplCmdHandler {
client: ChatGptClient,
config: SharedConfig,
state: RefCell<ReplCmdHandlerState>,
ctrlc: Arc<AtomicBool>,
abort: SharedAbortSignal,
}
pub struct ReplCmdHandlerState {
@ -34,9 +34,12 @@ pub struct ReplCmdHandlerState {
}
impl ReplCmdHandler {
pub fn init(client: ChatGptClient, config: SharedConfig) -> Result<Self> {
pub fn init(
client: ChatGptClient,
config: SharedConfig,
abort: SharedAbortSignal,
) -> Result<Self> {
let save_file = config.as_ref().borrow().open_message_file()?;
let ctrlc = Arc::new(AtomicBool::new(false));
let state = RefCell::new(ReplCmdHandlerState {
save_file,
reply: String::new(),
@ -45,7 +48,7 @@ impl ReplCmdHandler {
client,
config,
state,
ctrlc,
abort,
})
}
@ -61,15 +64,15 @@ impl ReplCmdHandler {
let highlight = self.config.borrow().highlight;
let mut stream_handler = if highlight {
let (tx, rx) = unbounded();
let ctrlc = self.ctrlc.clone();
let abort = self.abort.clone();
let wg = wg.clone();
spawn(move || {
let _ = render_stream(rx, ctrlc);
let _ = render_stream(rx, abort);
drop(wg);
});
ReplyStreamHandler::new(Some(tx), self.ctrlc.clone())
ReplyStreamHandler::new(Some(tx), self.abort.clone())
} else {
ReplyStreamHandler::new(None, self.ctrlc.clone())
ReplyStreamHandler::new(None, self.abort.clone())
};
self.client
.send_message_streaming(&input, prompt, &mut stream_handler)?;
@ -109,23 +112,19 @@ impl ReplCmdHandler {
pub fn get_reply(&self) -> String {
self.state.borrow().reply.to_string()
}
pub fn get_ctrlc(&self) -> Arc<AtomicBool> {
self.ctrlc.clone()
}
}
pub struct ReplyStreamHandler {
sender: Option<Sender<ReplyStreamEvent>>,
buffer: String,
ctrlc: Arc<AtomicBool>,
abort: SharedAbortSignal,
}
impl ReplyStreamHandler {
pub fn new(sender: Option<Sender<ReplyStreamEvent>>, ctrlc: Arc<AtomicBool>) -> Self {
pub fn new(sender: Option<Sender<ReplyStreamEvent>>, abort: SharedAbortSignal) -> Self {
Self {
sender,
ctrlc,
abort,
buffer: String::new(),
}
}
@ -157,8 +156,8 @@ impl ReplyStreamHandler {
&self.buffer
}
pub fn get_ctrlc(&self) -> Arc<AtomicBool> {
self.ctrlc.clone()
pub fn get_abort(&self) -> SharedAbortSignal {
self.abort.clone()
}
}

View File

@ -1,3 +1,4 @@
mod abort;
mod handler;
mod init;
@ -8,9 +9,9 @@ use crate::utils::{copy, dump};
use anyhow::{Context, Result};
use reedline::{DefaultPrompt, Reedline, Signal};
use std::sync::atomic::Ordering;
use std::sync::Arc;
pub use self::abort::*;
pub use self::handler::*;
pub const REPL_COMMANDS: [(&str, &str, bool); 12] = [
@ -35,23 +36,27 @@ pub struct Repl {
impl Repl {
pub fn run(&mut self, client: ChatGptClient, config: SharedConfig) -> Result<()> {
let handler = ReplCmdHandler::init(client, config)?;
let abort = AbortSignal::new();
let handler = ReplCmdHandler::init(client, config, abort.clone())?;
dump(
format!("Welcome to aichat {}", env!("CARGO_PKG_VERSION")),
1,
);
dump("Type \".help\" for more information.", 1);
let mut current_ctrlc = false;
let mut already_ctrlc = false;
let handler = Arc::new(handler);
loop {
let handler_ctrlc = handler.get_ctrlc();
if handler_ctrlc.load(Ordering::SeqCst) {
handler_ctrlc.store(false, Ordering::SeqCst);
current_ctrlc = true
if abort.aborted_ctrld() {
break;
}
match self.editor.read_line(&self.prompt) {
if abort.aborted_ctrlc() && !already_ctrlc {
already_ctrlc = true;
}
let sig = self.editor.read_line(&self.prompt);
match sig {
Ok(Signal::Success(line)) => {
current_ctrlc = false;
already_ctrlc = false;
abort.reset();
match self.handle_line(handler.clone(), line) {
Ok(quit) => {
if quit {
@ -65,14 +70,16 @@ impl Repl {
}
}
Ok(Signal::CtrlC) => {
if !current_ctrlc {
current_ctrlc = true;
abort.set_ctrlc();
if !already_ctrlc {
already_ctrlc = true;
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
} else {
break;
}
}
Ok(Signal::CtrlD) => {
abort.set_ctrld();
break;
}
_ => {}