feat: command mode supports stream out (#31)

* feat: command mode supports stream out

* update cli
pull/32/head
sigoden 1 year ago committed by GitHub
parent c7fcdb1744
commit 360264121c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

29
Cargo.lock generated

@ -20,6 +20,7 @@ dependencies = [
"clap", "clap",
"crossbeam", "crossbeam",
"crossterm 0.26.1", "crossterm 0.26.1",
"ctrlc",
"dirs", "dirs",
"eventsource-stream", "eventsource-stream",
"futures-util", "futures-util",
@ -320,6 +321,16 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "ctrlc"
version = "3.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbcf33c2a618cbe41ee43ae6e9f2e48368cd9f9db2896f10167d8d762679f639"
dependencies = [
"nix",
"windows-sys",
]
[[package]] [[package]]
name = "cxx" name = "cxx"
version = "1.0.92" version = "1.0.92"
@ -886,6 +897,18 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "nix"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a"
dependencies = [
"bitflags",
"cfg-if",
"libc",
"static_assertions",
]
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@ -1362,6 +1385,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "strip-ansi-escapes" name = "strip-ansi-escapes"
version = "0.1.1" version = "0.1.1"

@ -31,6 +31,7 @@ chrono = "0.4.23"
atty = "0.2.14" atty = "0.2.14"
unicode-width = "0.1.10" unicode-width = "0.1.10"
bincode = "1.3.3" bincode = "1.3.3"
ctrlc = "3.2.5"
[dependencies.syntect] [dependencies.syntect]
version = "5.0.0" version = "5.0.0"

@ -6,8 +6,11 @@ pub struct Cli {
/// Turn off highlight /// Turn off highlight
#[clap(short = 'H', long)] #[clap(short = 'H', long)]
pub no_highlight: bool, pub no_highlight: bool,
/// No stream output
#[clap(short = 'S', long)]
pub no_stream: bool,
/// List all roles /// List all roles
#[clap(short = 'L', long)] #[clap(long)]
pub list_roles: bool, pub list_roles: bool,
/// Select a role /// Select a role
#[clap(short, long)] #[clap(short, long)]

@ -14,12 +14,13 @@ use std::{io::stdout, process::exit};
use cli::Cli; use cli::Cli;
use client::ChatGptClient; use client::ChatGptClient;
use config::{Config, SharedConfig}; use config::{Config, SharedConfig};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal; use is_terminal::IsTerminal;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use clap::Parser; use clap::Parser;
use render::MarkdownRender; use render::{render_stream, MarkdownRender};
use repl::Repl; use repl::{AbortSignal, Repl};
fn main() -> Result<()> { fn main() -> Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
@ -46,6 +47,7 @@ fn main() -> Result<()> {
if cli.no_highlight { if cli.no_highlight {
config.borrow_mut().highlight = false; config.borrow_mut().highlight = false;
} }
let no_stream = cli.no_stream;
let client = ChatGptClient::init(config.clone())?; let client = ChatGptClient::init(config.clone())?;
if atty::isnt(atty::Stream::Stdin) { if atty::isnt(atty::Stream::Stdin) {
let mut input = String::new(); let mut input = String::new();
@ -53,28 +55,46 @@ fn main() -> Result<()> {
if let Some(text) = text { if let Some(text) = text {
input = format!("{text}\n{input}"); input = format!("{text}\n{input}");
} }
start_directive(client, config, &input) start_directive(client, config, &input, no_stream)
} else { } else {
match text { match text {
Some(text) => start_directive(client, config, &text), Some(text) => start_directive(client, config, &text, no_stream),
None => start_interactive(client, config), None => start_interactive(client, config),
} }
} }
} }
fn start_directive(client: ChatGptClient, config: SharedConfig, input: &str) -> Result<()> { fn start_directive(
client: ChatGptClient,
config: SharedConfig,
input: &str,
no_stream: bool,
) -> Result<()> {
let mut file = config.borrow().open_message_file()?; let mut file = config.borrow().open_message_file()?;
let prompt = config.borrow().get_prompt(); let prompt = config.borrow().get_prompt();
let output = client.send_message(input, prompt)?; let highlight = config.borrow().highlight && stdout().is_terminal();
let output = output.trim(); let output = if no_stream {
if config.borrow().highlight && stdout().is_terminal() { let output = client.send_message(input, prompt)?;
let mut markdown_render = MarkdownRender::new(); if highlight {
println!("{}", markdown_render.render(output)) let mut markdown_render = MarkdownRender::new();
println!("{}", markdown_render.render(&output));
} else {
println!("{output}");
}
output
} else { } else {
println!("{output}"); let wg = WaitGroup::new();
} let abort = AbortSignal::new();
let abort_clone = abort.clone();
config.borrow().save_message(file.as_mut(), input, output) ctrlc::set_handler(move || {
abort_clone.set_ctrlc();
})
.expect("Error setting Ctrl-C handler");
let output = render_stream(input, None, &client, highlight, false, abort, wg.clone())?;
wg.wait();
output
};
config.borrow().save_message(file.as_mut(), input, &output)
} }
fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> { fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> {

@ -0,0 +1,37 @@
use super::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use crate::utils::dump;
use anyhow::Result;
use crossbeam::channel::Receiver;
pub fn cmd_render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new();
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = lines.join("\n");
dump(markdown_render.render(&output), 1);
} else {
buffer = format!("{buffer}{text}");
}
}
ReplyStreamEvent::Done => {
let output = markdown_render.render(&buffer);
dump(output, 2);
break;
}
}
}
}
Ok(())
}

@ -1,125 +1,44 @@
mod cmd;
mod markdown; mod markdown;
mod repl;
use self::cmd::cmd_render_stream;
pub use self::markdown::MarkdownRender; pub use self::markdown::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; use self::repl::repl_render_stream;
use crate::client::ChatGptClient;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
use anyhow::Result; use anyhow::Result;
use crossbeam::channel::Receiver; use crossbeam::channel::unbounded;
use crossterm::{ use crossbeam::sync::WaitGroup;
cursor, use std::thread::spawn;
event::{self, Event, KeyCode, KeyModifiers},
queue, style, pub fn render_stream(
terminal::{self, disable_raw_mode, enable_raw_mode}, input: &str,
}; prompt: Option<String>,
use std::{ client: &ChatGptClient,
io::{self, Stdout, Write}, highlight: bool,
time::{Duration, Instant}, repl: bool,
};
use unicode_width::UnicodeWidthStr;
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
queue!(stdout, event::DisableMouseCapture)?;
let ret = render_stream_inner(rx, abort, &mut stdout);
queue!(stdout, event::DisableMouseCapture)?;
disable_raw_mode()?;
ret
}
pub fn render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
abort: SharedAbortSignal, abort: SharedAbortSignal,
writer: &mut Stdout, wg: WaitGroup,
) -> Result<()> { ) -> Result<String> {
let mut last_tick = Instant::now(); let mut stream_handler = if highlight {
let tick_rate = Duration::from_millis(100); let (tx, rx) = unbounded();
let mut buffer = String::new(); let abort_clone = abort.clone();
let mut markdown_render = MarkdownRender::new(); spawn(move || {
let terminal_columns = terminal::size()?.0; let _ = if repl {
loop { repl_render_stream(rx, abort)
if abort.aborted() { } else {
return Ok(()); cmd_render_stream(rx, abort)
} };
drop(wg);
if let Ok(evt) = rx.try_recv() { });
recover_cursor(writer, terminal_columns, &buffer)?; ReplyStreamHandler::new(Some(tx), abort_clone)
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = markdown_render.render(&lines.join("\n"));
for line in output.split('\n') {
queue!(
writer,
style::Print(line),
style::Print("\n"),
cursor::MoveLeft(terminal_columns),
)?;
}
queue!(writer, style::Print(&buffer),)?;
} else {
buffer = format!("{buffer}{text}");
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(&output))?;
}
writer.flush()?;
}
ReplyStreamEvent::Done => {
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(output), style::Print("\n"))?;
writer.flush()?;
break;
}
}
continue;
}
let timeout = tick_rate
.checked_sub(last_tick.elapsed())
.unwrap_or_else(|| Duration::from_secs(0));
if crossterm::event::poll(timeout)? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrlc();
return Ok(());
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
return Ok(());
}
_ => {}
}
}
}
if last_tick.elapsed() >= tick_rate {
last_tick = Instant::now();
}
}
Ok(())
}
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns;
let (_, row) = cursor::position()?;
if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?;
} else if row + 1 >= buffer_rows {
queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?;
} else { } else {
queue!( drop(wg);
writer, ReplyStreamHandler::new(None, abort)
terminal::ScrollUp(buffer_rows - 1 - row), };
cursor::MoveTo(0, 0) client.send_message_streaming(input, prompt, &mut stream_handler)?;
)?; let buffer = stream_handler.get_buffer();
} Ok(buffer.to_string())
Ok(())
} }

@ -0,0 +1,123 @@
use super::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyModifiers},
queue, style,
terminal::{self, disable_raw_mode, enable_raw_mode},
};
use std::{
io::{self, Stdout, Write},
time::{Duration, Instant},
};
use unicode_width::UnicodeWidthStr;
pub fn repl_render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
queue!(stdout, event::DisableMouseCapture)?;
let ret = repl_render_stream_inner(rx, abort, &mut stdout);
queue!(stdout, event::DisableMouseCapture)?;
disable_raw_mode()?;
ret
}
fn repl_render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
abort: SharedAbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
let tick_rate = Duration::from_millis(100);
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new();
let terminal_columns = terminal::size()?.0;
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
recover_cursor(writer, terminal_columns, &buffer)?;
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = markdown_render.render(&lines.join("\n"));
for line in output.split('\n') {
queue!(
writer,
style::Print(line),
style::Print("\n"),
cursor::MoveLeft(terminal_columns),
)?;
}
queue!(writer, style::Print(&buffer),)?;
} else {
buffer = format!("{buffer}{text}");
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(&output))?;
}
writer.flush()?;
}
ReplyStreamEvent::Done => {
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(output), style::Print("\n"))?;
writer.flush()?;
break;
}
}
continue;
}
let timeout = tick_rate
.checked_sub(last_tick.elapsed())
.unwrap_or_else(|| Duration::from_secs(0));
if crossterm::event::poll(timeout)? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrlc();
return Ok(());
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
return Ok(());
}
_ => {}
}
}
}
if last_tick.elapsed() >= tick_rate {
last_tick = Instant::now();
}
}
Ok(())
}
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns;
let (_, row) = cursor::position()?;
if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?;
} else if row + 1 >= buffer_rows {
queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?;
} else {
queue!(
writer,
terminal::ScrollUp(buffer_rows - 1 - row),
cursor::MoveTo(0, 0)
)?;
}
Ok(())
}

@ -4,11 +4,10 @@ use crate::render::render_stream;
use crate::utils::dump; use crate::utils::dump;
use anyhow::Result; use anyhow::Result;
use crossbeam::channel::{unbounded, Sender}; use crossbeam::channel::Sender;
use crossbeam::sync::WaitGroup; use crossbeam::sync::WaitGroup;
use std::cell::RefCell; use std::cell::RefCell;
use std::fs::File; use std::fs::File;
use std::thread::spawn;
use super::abort::SharedAbortSignal; use super::abort::SharedAbortSignal;
@ -59,23 +58,26 @@ impl ReplCmdHandler {
self.state.borrow_mut().reply.clear(); self.state.borrow_mut().reply.clear();
return Ok(()); return Ok(());
} }
let wg = WaitGroup::new();
let highlight = self.config.borrow().highlight; let highlight = self.config.borrow().highlight;
let stream_handler = if highlight { let prompt = self.config.borrow().get_prompt();
let (tx, rx) = unbounded(); let wg = WaitGroup::new();
let abort = self.abort.clone(); let ret = render_stream(
let wg = wg.clone(); &input,
spawn(move || { prompt,
let _ = render_stream(rx, abort); &self.client,
drop(wg); highlight,
}); true,
ReplyStreamHandler::new(Some(tx), self.abort.clone()) self.abort.clone(),
} else { wg.clone(),
ReplyStreamHandler::new(None, self.abort.clone()) );
};
let ret = self.handle_send_stream(&input, stream_handler);
wg.wait(); wg.wait();
self.state.borrow_mut().reply = ret?; let buffer = ret?;
self.config.borrow().save_message(
self.state.borrow_mut().save_file.as_mut(),
&input,
&buffer,
)?;
self.state.borrow_mut().reply = buffer;
} }
ReplCmd::SetRole(name) => { ReplCmd::SetRole(name) => {
let output = self.config.borrow_mut().change_role(&name); let output = self.config.borrow_mut().change_role(&name);
@ -100,23 +102,6 @@ impl ReplCmdHandler {
} }
Ok(()) Ok(())
} }
fn handle_send_stream(
&self,
input: &str,
mut stream_handler: ReplyStreamHandler,
) -> Result<String> {
let prompt = self.config.borrow().get_prompt();
self.client
.send_message_streaming(input, prompt, &mut stream_handler)?;
let buffer = stream_handler.get_buffer();
self.config.borrow().save_message(
self.state.borrow_mut().save_file.as_mut(),
input,
buffer,
)?;
Ok(buffer.to_string())
}
} }
pub struct ReplyStreamHandler { pub struct ReplyStreamHandler {

Loading…
Cancel
Save