feat: command mode supports stream out (#31)

* feat: command mode supports stream out

* update cli
pull/32/head
sigoden 1 year ago committed by GitHub
parent c7fcdb1744
commit 360264121c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

29
Cargo.lock generated

@ -20,6 +20,7 @@ dependencies = [
"clap",
"crossbeam",
"crossterm 0.26.1",
"ctrlc",
"dirs",
"eventsource-stream",
"futures-util",
@ -320,6 +321,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "ctrlc"
version = "3.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbcf33c2a618cbe41ee43ae6e9f2e48368cd9f9db2896f10167d8d762679f639"
dependencies = [
"nix",
"windows-sys",
]
[[package]]
name = "cxx"
version = "1.0.92"
@ -886,6 +897,18 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "nix"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a"
dependencies = [
"bitflags",
"cfg-if",
"libc",
"static_assertions",
]
[[package]]
name = "nom"
version = "7.1.3"
@ -1362,6 +1385,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "strip-ansi-escapes"
version = "0.1.1"

@ -31,6 +31,7 @@ chrono = "0.4.23"
atty = "0.2.14"
unicode-width = "0.1.10"
bincode = "1.3.3"
ctrlc = "3.2.5"
[dependencies.syntect]
version = "5.0.0"

@ -6,8 +6,11 @@ pub struct Cli {
/// Turn off highlight
#[clap(short = 'H', long)]
pub no_highlight: bool,
/// No stream output
#[clap(short = 'S', long)]
pub no_stream: bool,
/// List all roles
#[clap(short = 'L', long)]
#[clap(long)]
pub list_roles: bool,
/// Select a role
#[clap(short, long)]

@ -14,12 +14,13 @@ use std::{io::stdout, process::exit};
use cli::Cli;
use client::ChatGptClient;
use config::{Config, SharedConfig};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use anyhow::{anyhow, Result};
use clap::Parser;
use render::MarkdownRender;
use repl::Repl;
use render::{render_stream, MarkdownRender};
use repl::{AbortSignal, Repl};
fn main() -> Result<()> {
let cli = Cli::parse();
@ -46,6 +47,7 @@ fn main() -> Result<()> {
if cli.no_highlight {
config.borrow_mut().highlight = false;
}
let no_stream = cli.no_stream;
let client = ChatGptClient::init(config.clone())?;
if atty::isnt(atty::Stream::Stdin) {
let mut input = String::new();
@ -53,28 +55,46 @@ fn main() -> Result<()> {
if let Some(text) = text {
input = format!("{text}\n{input}");
}
start_directive(client, config, &input)
start_directive(client, config, &input, no_stream)
} else {
match text {
Some(text) => start_directive(client, config, &text),
Some(text) => start_directive(client, config, &text, no_stream),
None => start_interactive(client, config),
}
}
}
fn start_directive(client: ChatGptClient, config: SharedConfig, input: &str) -> Result<()> {
fn start_directive(
client: ChatGptClient,
config: SharedConfig,
input: &str,
no_stream: bool,
) -> Result<()> {
let mut file = config.borrow().open_message_file()?;
let prompt = config.borrow().get_prompt();
let output = client.send_message(input, prompt)?;
let output = output.trim();
if config.borrow().highlight && stdout().is_terminal() {
let mut markdown_render = MarkdownRender::new();
println!("{}", markdown_render.render(output))
let highlight = config.borrow().highlight && stdout().is_terminal();
let output = if no_stream {
let output = client.send_message(input, prompt)?;
if highlight {
let mut markdown_render = MarkdownRender::new();
println!("{}", markdown_render.render(&output));
} else {
println!("{output}");
}
output
} else {
println!("{output}");
}
config.borrow().save_message(file.as_mut(), input, output)
let wg = WaitGroup::new();
let abort = AbortSignal::new();
let abort_clone = abort.clone();
ctrlc::set_handler(move || {
abort_clone.set_ctrlc();
})
.expect("Error setting Ctrl-C handler");
let output = render_stream(input, None, &client, highlight, false, abort, wg.clone())?;
wg.wait();
output
};
config.borrow().save_message(file.as_mut(), input, &output)
}
fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> {

@ -0,0 +1,37 @@
use super::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use crate::utils::dump;
use anyhow::Result;
use crossbeam::channel::Receiver;
pub fn cmd_render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new();
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = lines.join("\n");
dump(markdown_render.render(&output), 1);
} else {
buffer = format!("{buffer}{text}");
}
}
ReplyStreamEvent::Done => {
let output = markdown_render.render(&buffer);
dump(output, 2);
break;
}
}
}
}
Ok(())
}

@ -1,125 +1,44 @@
mod cmd;
mod markdown;
mod repl;
use self::cmd::cmd_render_stream;
pub use self::markdown::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use self::repl::repl_render_stream;
use crate::client::ChatGptClient;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyModifiers},
queue, style,
terminal::{self, disable_raw_mode, enable_raw_mode},
};
use std::{
io::{self, Stdout, Write},
time::{Duration, Instant},
};
use unicode_width::UnicodeWidthStr;
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
queue!(stdout, event::DisableMouseCapture)?;
let ret = render_stream_inner(rx, abort, &mut stdout);
queue!(stdout, event::DisableMouseCapture)?;
disable_raw_mode()?;
ret
}
pub fn render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
use crossbeam::channel::unbounded;
use crossbeam::sync::WaitGroup;
use std::thread::spawn;
pub fn render_stream(
input: &str,
prompt: Option<String>,
client: &ChatGptClient,
highlight: bool,
repl: bool,
abort: SharedAbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
let tick_rate = Duration::from_millis(100);
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new();
let terminal_columns = terminal::size()?.0;
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
recover_cursor(writer, terminal_columns, &buffer)?;
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = markdown_render.render(&lines.join("\n"));
for line in output.split('\n') {
queue!(
writer,
style::Print(line),
style::Print("\n"),
cursor::MoveLeft(terminal_columns),
)?;
}
queue!(writer, style::Print(&buffer),)?;
} else {
buffer = format!("{buffer}{text}");
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(&output))?;
}
writer.flush()?;
}
ReplyStreamEvent::Done => {
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(output), style::Print("\n"))?;
writer.flush()?;
break;
}
}
continue;
}
let timeout = tick_rate
.checked_sub(last_tick.elapsed())
.unwrap_or_else(|| Duration::from_secs(0));
if crossterm::event::poll(timeout)? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrlc();
return Ok(());
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
return Ok(());
}
_ => {}
}
}
}
if last_tick.elapsed() >= tick_rate {
last_tick = Instant::now();
}
}
Ok(())
}
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns;
let (_, row) = cursor::position()?;
if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?;
} else if row + 1 >= buffer_rows {
queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?;
wg: WaitGroup,
) -> Result<String> {
let mut stream_handler = if highlight {
let (tx, rx) = unbounded();
let abort_clone = abort.clone();
spawn(move || {
let _ = if repl {
repl_render_stream(rx, abort)
} else {
cmd_render_stream(rx, abort)
};
drop(wg);
});
ReplyStreamHandler::new(Some(tx), abort_clone)
} else {
queue!(
writer,
terminal::ScrollUp(buffer_rows - 1 - row),
cursor::MoveTo(0, 0)
)?;
}
Ok(())
drop(wg);
ReplyStreamHandler::new(None, abort)
};
client.send_message_streaming(input, prompt, &mut stream_handler)?;
let buffer = stream_handler.get_buffer();
Ok(buffer.to_string())
}

@ -0,0 +1,123 @@
use super::MarkdownRender;
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyModifiers},
queue, style,
terminal::{self, disable_raw_mode, enable_raw_mode},
};
use std::{
io::{self, Stdout, Write},
time::{Duration, Instant},
};
use unicode_width::UnicodeWidthStr;
pub fn repl_render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
queue!(stdout, event::DisableMouseCapture)?;
let ret = repl_render_stream_inner(rx, abort, &mut stdout);
queue!(stdout, event::DisableMouseCapture)?;
disable_raw_mode()?;
ret
}
fn repl_render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
abort: SharedAbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
let tick_rate = Duration::from_millis(100);
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new();
let terminal_columns = terminal::size()?.0;
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
recover_cursor(writer, terminal_columns, &buffer)?;
match evt {
ReplyStreamEvent::Text(text) => {
if text.contains('\n') {
let text = format!("{buffer}{text}");
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = markdown_render.render(&lines.join("\n"));
for line in output.split('\n') {
queue!(
writer,
style::Print(line),
style::Print("\n"),
cursor::MoveLeft(terminal_columns),
)?;
}
queue!(writer, style::Print(&buffer),)?;
} else {
buffer = format!("{buffer}{text}");
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(&output))?;
}
writer.flush()?;
}
ReplyStreamEvent::Done => {
let output = markdown_render.render_line_stateless(&buffer);
queue!(writer, style::Print(output), style::Print("\n"))?;
writer.flush()?;
break;
}
}
continue;
}
let timeout = tick_rate
.checked_sub(last_tick.elapsed())
.unwrap_or_else(|| Duration::from_secs(0));
if crossterm::event::poll(timeout)? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrlc();
return Ok(());
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
return Ok(());
}
_ => {}
}
}
}
if last_tick.elapsed() >= tick_rate {
last_tick = Instant::now();
}
}
Ok(())
}
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns;
let (_, row) = cursor::position()?;
if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?;
} else if row + 1 >= buffer_rows {
queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows))?;
} else {
queue!(
writer,
terminal::ScrollUp(buffer_rows - 1 - row),
cursor::MoveTo(0, 0)
)?;
}
Ok(())
}

@ -4,11 +4,10 @@ use crate::render::render_stream;
use crate::utils::dump;
use anyhow::Result;
use crossbeam::channel::{unbounded, Sender};
use crossbeam::channel::Sender;
use crossbeam::sync::WaitGroup;
use std::cell::RefCell;
use std::fs::File;
use std::thread::spawn;
use super::abort::SharedAbortSignal;
@ -59,23 +58,26 @@ impl ReplCmdHandler {
self.state.borrow_mut().reply.clear();
return Ok(());
}
let wg = WaitGroup::new();
let highlight = self.config.borrow().highlight;
let stream_handler = if highlight {
let (tx, rx) = unbounded();
let abort = self.abort.clone();
let wg = wg.clone();
spawn(move || {
let _ = render_stream(rx, abort);
drop(wg);
});
ReplyStreamHandler::new(Some(tx), self.abort.clone())
} else {
ReplyStreamHandler::new(None, self.abort.clone())
};
let ret = self.handle_send_stream(&input, stream_handler);
let prompt = self.config.borrow().get_prompt();
let wg = WaitGroup::new();
let ret = render_stream(
&input,
prompt,
&self.client,
highlight,
true,
self.abort.clone(),
wg.clone(),
);
wg.wait();
self.state.borrow_mut().reply = ret?;
let buffer = ret?;
self.config.borrow().save_message(
self.state.borrow_mut().save_file.as_mut(),
&input,
&buffer,
)?;
self.state.borrow_mut().reply = buffer;
}
ReplCmd::SetRole(name) => {
let output = self.config.borrow_mut().change_role(&name);
@ -100,23 +102,6 @@ impl ReplCmdHandler {
}
Ok(())
}
fn handle_send_stream(
&self,
input: &str,
mut stream_handler: ReplyStreamHandler,
) -> Result<String> {
let prompt = self.config.borrow().get_prompt();
self.client
.send_message_streaming(input, prompt, &mut stream_handler)?;
let buffer = stream_handler.get_buffer();
self.config.borrow().save_message(
self.state.borrow_mut().save_file.as_mut(),
input,
buffer,
)?;
Ok(buffer.to_string())
}
}
pub struct ReplyStreamHandler {

Loading…
Cancel
Save