diff --git a/Cargo.lock b/Cargo.lock index 7d5be2f..69a94d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,7 @@ dependencies = [ "clap", "crossbeam", "crossterm 0.26.1", + "ctrlc", "dirs", "eventsource-stream", "futures-util", @@ -320,6 +321,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "ctrlc" +version = "3.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbcf33c2a618cbe41ee43ae6e9f2e48368cd9f9db2896f10167d8d762679f639" +dependencies = [ + "nix", + "windows-sys", +] + [[package]] name = "cxx" version = "1.0.92" @@ -886,6 +897,18 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "static_assertions", +] + [[package]] name = "nom" version = "7.1.3" @@ -1362,6 +1385,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strip-ansi-escapes" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index e997bbd..828f880 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ chrono = "0.4.23" atty = "0.2.14" unicode-width = "0.1.10" bincode = "1.3.3" +ctrlc = "3.2.5" [dependencies.syntect] version = "5.0.0" diff --git a/src/cli.rs b/src/cli.rs index 3d60c9c..f2d2640 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -6,8 +6,11 @@ pub struct Cli { /// Turn off highlight #[clap(short = 'H', long)] pub no_highlight: bool, + /// No stream output + #[clap(short = 'S', long)] + pub no_stream: bool, /// List all roles - #[clap(short = 'L', long)] + #[clap(long)] pub list_roles: bool, /// Select a role #[clap(short, long)] diff --git a/src/main.rs b/src/main.rs index f1ae303..dfce165 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,12 +14,13 @@ use std::{io::stdout, process::exit}; use cli::Cli; use client::ChatGptClient; use config::{Config, SharedConfig}; +use crossbeam::sync::WaitGroup; use is_terminal::IsTerminal; use anyhow::{anyhow, Result}; use clap::Parser; -use render::MarkdownRender; -use repl::Repl; +use render::{render_stream, MarkdownRender}; +use repl::{AbortSignal, Repl}; fn main() -> Result<()> { let cli = Cli::parse(); @@ -46,6 +47,7 @@ fn main() -> Result<()> { if cli.no_highlight { config.borrow_mut().highlight = false; } + let no_stream = cli.no_stream; let client = ChatGptClient::init(config.clone())?; if atty::isnt(atty::Stream::Stdin) { let mut input = String::new(); @@ -53,28 +55,46 @@ fn main() -> Result<()> { if let Some(text) = text { input = format!("{text}\n{input}"); } - start_directive(client, config, &input) + start_directive(client, config, &input, no_stream) } else { match text { - Some(text) => start_directive(client, config, &text), + Some(text) => start_directive(client, config, &text, no_stream), None => start_interactive(client, config), } } } -fn start_directive(client: ChatGptClient, config: SharedConfig, input: &str) -> Result<()> { +fn start_directive( + client: ChatGptClient, + config: SharedConfig, + input: &str, + no_stream: bool, +) -> Result<()> { let mut file = config.borrow().open_message_file()?; let prompt = config.borrow().get_prompt(); - let output = client.send_message(input, prompt)?; - let output = output.trim(); - if config.borrow().highlight && stdout().is_terminal() { - let mut markdown_render = MarkdownRender::new(); - println!("{}", markdown_render.render(output)) + let highlight = config.borrow().highlight && stdout().is_terminal(); + let output = if no_stream { + let output = client.send_message(input, prompt)?; + if highlight { + let mut markdown_render = MarkdownRender::new(); + println!("{}", markdown_render.render(&output)); + } else { + println!("{output}"); + } + output } else { - println!("{output}"); - } - - config.borrow().save_message(file.as_mut(), input, output) + let wg = WaitGroup::new(); + let abort = AbortSignal::new(); + let abort_clone = abort.clone(); + ctrlc::set_handler(move || { + abort_clone.set_ctrlc(); + }) + .expect("Error setting Ctrl-C handler"); + let output = render_stream(input, None, &client, highlight, false, abort, wg.clone())?; + wg.wait(); + output + }; + config.borrow().save_message(file.as_mut(), input, &output) } fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> { diff --git a/src/render/cmd.rs b/src/render/cmd.rs new file mode 100644 index 0000000..f29fe7d --- /dev/null +++ b/src/render/cmd.rs @@ -0,0 +1,37 @@ +use super::MarkdownRender; +use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; +use crate::utils::dump; + +use anyhow::Result; +use crossbeam::channel::Receiver; + +pub fn cmd_render_stream(rx: Receiver, abort: SharedAbortSignal) -> Result<()> { + let mut buffer = String::new(); + let mut markdown_render = MarkdownRender::new(); + loop { + if abort.aborted() { + return Ok(()); + } + if let Ok(evt) = rx.try_recv() { + match evt { + ReplyStreamEvent::Text(text) => { + if text.contains('\n') { + let text = format!("{buffer}{text}"); + let mut lines: Vec<&str> = text.split('\n').collect(); + buffer = lines.pop().unwrap_or_default().to_string(); + let output = lines.join("\n"); + dump(markdown_render.render(&output), 1); + } else { + buffer = format!("{buffer}{text}"); + } + } + ReplyStreamEvent::Done => { + let output = markdown_render.render(&buffer); + dump(output, 2); + break; + } + } + } + } + Ok(()) +} diff --git a/src/render/mod.rs b/src/render/mod.rs index e0f42e7..efbcdd7 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -1,125 +1,44 @@ +mod cmd; mod markdown; +mod repl; +use self::cmd::cmd_render_stream; pub use self::markdown::MarkdownRender; -use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; +use self::repl::repl_render_stream; +use crate::client::ChatGptClient; +use crate::repl::{ReplyStreamHandler, SharedAbortSignal}; use anyhow::Result; -use crossbeam::channel::Receiver; -use crossterm::{ - cursor, - event::{self, Event, KeyCode, KeyModifiers}, - queue, style, - terminal::{self, disable_raw_mode, enable_raw_mode}, -}; -use std::{ - io::{self, Stdout, Write}, - time::{Duration, Instant}, -}; -use unicode_width::UnicodeWidthStr; - -pub fn render_stream(rx: Receiver, abort: SharedAbortSignal) -> Result<()> { - enable_raw_mode()?; - let mut stdout = io::stdout(); - queue!(stdout, event::DisableMouseCapture)?; - - let ret = render_stream_inner(rx, abort, &mut stdout); - - queue!(stdout, event::DisableMouseCapture)?; - disable_raw_mode()?; - - ret -} - -pub fn render_stream_inner( - rx: Receiver, +use crossbeam::channel::unbounded; +use crossbeam::sync::WaitGroup; +use std::thread::spawn; + +pub fn render_stream( + input: &str, + prompt: Option, + client: &ChatGptClient, + highlight: bool, + repl: bool, abort: SharedAbortSignal, - writer: &mut Stdout, -) -> Result<()> { - let mut last_tick = Instant::now(); - let tick_rate = Duration::from_millis(100); - let mut buffer = String::new(); - let mut markdown_render = MarkdownRender::new(); - let terminal_columns = terminal::size()?.0; - loop { - if abort.aborted() { - return Ok(()); - } - - if let Ok(evt) = rx.try_recv() { - recover_cursor(writer, terminal_columns, &buffer)?; - - match evt { - ReplyStreamEvent::Text(text) => { - if text.contains('\n') { - let text = format!("{buffer}{text}"); - let mut lines: Vec<&str> = text.split('\n').collect(); - buffer = lines.pop().unwrap_or_default().to_string(); - let output = markdown_render.render(&lines.join("\n")); - for line in output.split('\n') { - queue!( - writer, - style::Print(line), - style::Print("\n"), - cursor::MoveLeft(terminal_columns), - )?; - } - queue!(writer, style::Print(&buffer),)?; - } else { - buffer = format!("{buffer}{text}"); - let output = markdown_render.render_line_stateless(&buffer); - queue!(writer, style::Print(&output))?; - } - writer.flush()?; - } - ReplyStreamEvent::Done => { - let output = markdown_render.render_line_stateless(&buffer); - queue!(writer, style::Print(output), style::Print("\n"))?; - writer.flush()?; - break; - } - } - continue; - } - - let timeout = tick_rate - .checked_sub(last_tick.elapsed()) - .unwrap_or_else(|| Duration::from_secs(0)); - if crossterm::event::poll(timeout)? { - if let Event::Key(key) = event::read()? { - match key.code { - KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => { - abort.set_ctrlc(); - return Ok(()); - } - KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => { - abort.set_ctrld(); - return Ok(()); - } - _ => {} - } - } - } - - if last_tick.elapsed() >= tick_rate { - last_tick = Instant::now(); - } - } - Ok(()) -} - -fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> { - let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns; - let (_, row) = cursor::position()?; - if buffer_rows == 0 { - queue!(writer, cursor::MoveTo(0, row))?; - } else if row + 1 >= buffer_rows { - queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?; + wg: WaitGroup, +) -> Result { + let mut stream_handler = if highlight { + let (tx, rx) = unbounded(); + let abort_clone = abort.clone(); + spawn(move || { + let _ = if repl { + repl_render_stream(rx, abort) + } else { + cmd_render_stream(rx, abort) + }; + drop(wg); + }); + ReplyStreamHandler::new(Some(tx), abort_clone) } else { - queue!( - writer, - terminal::ScrollUp(buffer_rows - 1 - row), - cursor::MoveTo(0, 0) - )?; - } - Ok(()) + drop(wg); + ReplyStreamHandler::new(None, abort) + }; + client.send_message_streaming(input, prompt, &mut stream_handler)?; + let buffer = stream_handler.get_buffer(); + Ok(buffer.to_string()) } diff --git a/src/render/repl.rs b/src/render/repl.rs new file mode 100644 index 0000000..96c6364 --- /dev/null +++ b/src/render/repl.rs @@ -0,0 +1,123 @@ +use super::MarkdownRender; +use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; + +use anyhow::Result; +use crossbeam::channel::Receiver; +use crossterm::{ + cursor, + event::{self, Event, KeyCode, KeyModifiers}, + queue, style, + terminal::{self, disable_raw_mode, enable_raw_mode}, +}; +use std::{ + io::{self, Stdout, Write}, + time::{Duration, Instant}, +}; +use unicode_width::UnicodeWidthStr; + +pub fn repl_render_stream(rx: Receiver, abort: SharedAbortSignal) -> Result<()> { + enable_raw_mode()?; + let mut stdout = io::stdout(); + queue!(stdout, event::DisableMouseCapture)?; + + let ret = repl_render_stream_inner(rx, abort, &mut stdout); + + queue!(stdout, event::DisableMouseCapture)?; + disable_raw_mode()?; + + ret +} + +fn repl_render_stream_inner( + rx: Receiver, + abort: SharedAbortSignal, + writer: &mut Stdout, +) -> Result<()> { + let mut last_tick = Instant::now(); + let tick_rate = Duration::from_millis(100); + let mut buffer = String::new(); + let mut markdown_render = MarkdownRender::new(); + let terminal_columns = terminal::size()?.0; + loop { + if abort.aborted() { + return Ok(()); + } + + if let Ok(evt) = rx.try_recv() { + recover_cursor(writer, terminal_columns, &buffer)?; + + match evt { + ReplyStreamEvent::Text(text) => { + if text.contains('\n') { + let text = format!("{buffer}{text}"); + let mut lines: Vec<&str> = text.split('\n').collect(); + buffer = lines.pop().unwrap_or_default().to_string(); + let output = markdown_render.render(&lines.join("\n")); + for line in output.split('\n') { + queue!( + writer, + style::Print(line), + style::Print("\n"), + cursor::MoveLeft(terminal_columns), + )?; + } + queue!(writer, style::Print(&buffer),)?; + } else { + buffer = format!("{buffer}{text}"); + let output = markdown_render.render_line_stateless(&buffer); + queue!(writer, style::Print(&output))?; + } + writer.flush()?; + } + ReplyStreamEvent::Done => { + let output = markdown_render.render_line_stateless(&buffer); + queue!(writer, style::Print(output), style::Print("\n"))?; + writer.flush()?; + break; + } + } + continue; + } + + let timeout = tick_rate + .checked_sub(last_tick.elapsed()) + .unwrap_or_else(|| Duration::from_secs(0)); + if crossterm::event::poll(timeout)? { + if let Event::Key(key) = event::read()? { + match key.code { + KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => { + abort.set_ctrlc(); + return Ok(()); + } + KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => { + abort.set_ctrld(); + return Ok(()); + } + _ => {} + } + } + } + + if last_tick.elapsed() >= tick_rate { + last_tick = Instant::now(); + } + } + Ok(()) +} + +fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> { + let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns; + let (_, row) = cursor::position()?; + if buffer_rows == 0 { + queue!(writer, cursor::MoveTo(0, row))?; + } else if row + 1 >= buffer_rows { + queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?; + } else { + queue!( + writer, + terminal::ScrollUp(buffer_rows - 1 - row), + cursor::MoveTo(0, 0) + )?; + } + Ok(()) +} diff --git a/src/repl/handler.rs b/src/repl/handler.rs index add3eac..384fce0 100644 --- a/src/repl/handler.rs +++ b/src/repl/handler.rs @@ -4,11 +4,10 @@ use crate::render::render_stream; use crate::utils::dump; use anyhow::Result; -use crossbeam::channel::{unbounded, Sender}; +use crossbeam::channel::Sender; use crossbeam::sync::WaitGroup; use std::cell::RefCell; use std::fs::File; -use std::thread::spawn; use super::abort::SharedAbortSignal; @@ -59,23 +58,26 @@ impl ReplCmdHandler { self.state.borrow_mut().reply.clear(); return Ok(()); } - let wg = WaitGroup::new(); let highlight = self.config.borrow().highlight; - let stream_handler = if highlight { - let (tx, rx) = unbounded(); - let abort = self.abort.clone(); - let wg = wg.clone(); - spawn(move || { - let _ = render_stream(rx, abort); - drop(wg); - }); - ReplyStreamHandler::new(Some(tx), self.abort.clone()) - } else { - ReplyStreamHandler::new(None, self.abort.clone()) - }; - let ret = self.handle_send_stream(&input, stream_handler); + let prompt = self.config.borrow().get_prompt(); + let wg = WaitGroup::new(); + let ret = render_stream( + &input, + prompt, + &self.client, + highlight, + true, + self.abort.clone(), + wg.clone(), + ); wg.wait(); - self.state.borrow_mut().reply = ret?; + let buffer = ret?; + self.config.borrow().save_message( + self.state.borrow_mut().save_file.as_mut(), + &input, + &buffer, + )?; + self.state.borrow_mut().reply = buffer; } ReplCmd::SetRole(name) => { let output = self.config.borrow_mut().change_role(&name); @@ -100,23 +102,6 @@ impl ReplCmdHandler { } Ok(()) } - - fn handle_send_stream( - &self, - input: &str, - mut stream_handler: ReplyStreamHandler, - ) -> Result { - let prompt = self.config.borrow().get_prompt(); - self.client - .send_message_streaming(input, prompt, &mut stream_handler)?; - let buffer = stream_handler.get_buffer(); - self.config.borrow().save_message( - self.state.borrow_mut().save_file.as_mut(), - input, - buffer, - )?; - Ok(buffer.to_string()) - } } pub struct ReplyStreamHandler {