From 1cc89eff514669d6459f62cc4eed8e0b34d8ef0c Mon Sep 17 00:00:00 2001
From: sigoden
Date: Tue, 23 Apr 2024 14:32:06 +0800
Subject: [PATCH] refactor: more async code (#427)

---
 Cargo.lock                  |  80 +++++++-----------------
 Cargo.toml                  |   4 +-
 src/client/claude.rs        |  11 ++--
 src/client/cohere.rs        |   4 +-
 src/client/common.rs        | 120 ++++++++++++++++++++++--------------
 src/client/ernie.rs         |   7 ++-
 src/client/gemini.rs        |   4 +-
 src/client/mod.rs           |   2 +
 src/client/ollama.rs        |   6 +-
 src/client/openai.rs        |   4 +-
 src/client/qianwen.rs       |   7 +--
 src/client/reply_handler.rs |  65 +++++++++++++++++++
 src/client/vertexai.rs      |   4 +-
 src/config/mod.rs           |  40 +++++++-----
 src/main.rs                 |  47 +++++++-------
 src/render/mod.rs           | 119 ++++-------------------------------
 src/render/stream.rs        |  72 +++++++++++-----------
 src/repl/mod.rs             |  33 +++++-----
 src/utils/mod.rs            |  10 +--
 src/utils/spinner.rs        |  20 +++---
 20 files changed, 315 insertions(+), 344 deletions(-)
 create mode 100644 src/client/reply_handler.rs

diff --git a/Cargo.lock b/Cargo.lock
index fd25eff..f0448ae 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -33,6 +33,7 @@ dependencies = [
  "ansi_colours",
  "anyhow",
  "arboard",
+ "async-recursion",
  "async-trait",
  "base64 0.22.0",
  "bincode",
@@ -41,7 +42,6 @@ dependencies = [
  "bytes",
  "chrono",
  "clap",
- "crossbeam",
  "crossterm 0.27.0",
  "dirs",
  "fancy-regex",
@@ -164,6 +164,17 @@ dependencies = [
  "x11rb",
 ]

+[[package]]
+name = "async-recursion"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30c5ef0ede93efbf733c1a727f3b6b5a1060bbedd5600183e66f6e4be4af0ec5"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.79"
@@ -419,62 +430,6 @@ dependencies = [
  "cfg-if",
 ]

-[[package]]
-name = "crossbeam"
-version = "0.8.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
-dependencies = [
- "crossbeam-channel",
- "crossbeam-deque",
- "crossbeam-epoch",
- "crossbeam-queue",
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-channel"
-version = "0.5.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95"
-dependencies = [
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-deque"
-version = "0.8.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
-dependencies = [
- "crossbeam-epoch",
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-epoch"
-version = "0.9.18"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
-dependencies = [
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-queue"
-version = "0.3.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
-dependencies = [
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-utils"
-version = "0.8.19"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
-
 [[package]]
 name = "crossterm"
 version = "0.25.0"
@@ -1234,6 +1189,16 @@ dependencies = [
  "autocfg",
 ]

+[[package]]
+name = "num_cpus"
+version = "1.16.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
 [[package]]
 name = "num_threads"
 version = "0.1.7"
@@ -2092,6 +2057,7 @@ dependencies = [
  "bytes",
  "libc",
  "mio",
+ "num_cpus",
  "pin-project-lite",
  "signal-hook-registry",
  "socket2",
diff --git a/Cargo.toml b/Cargo.toml
index ebf977b..5e792e8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,8 +22,7 @@ reedline = "0.31.0"
 serde = { version = "1.0.152", features = ["derive"] }
 serde_json = { version = "1.0.93", features = ["preserve_order"] }
 serde_yaml = "0.9.17"
-tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal"] }
-crossbeam = "0.8.2"
+tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal", "rt-multi-thread"] }
 crossterm = "0.27.0"
 chrono = "0.4.23"
 bincode = "1.3.3"
@@ -45,6 +44,7 @@ mime_guess = "2.0.4"
 sha2 = "0.10.8"
 bitflags = "2.4.1"
 unicode-width = "0.1.11"
+async-recursion = "1.1.0"

 [dependencies.reqwest]
 version = "0.12.0"
diff --git a/src/client/claude.rs b/src/client/claude.rs
index 7a4dd36..4da5128 100644
--- a/src/client/claude.rs
+++ b/src/client/claude.rs
@@ -1,11 +1,10 @@
-use super::{patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData};
-
-use crate::{
-    client::{ImageUrl, MessageContent, MessageContentPart},
-    render::ReplyHandler,
-    utils::PromptKind,
+use super::{
+    patch_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
+    MessageContentPart, Model, PromptType, ReplyHandler, SendData,
 };

+use crate::utils::PromptKind;
+
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
 use futures_util::StreamExt;
diff --git a/src/client/cohere.rs b/src/client/cohere.rs
index 6f2e288..a92e238 100644
--- a/src/client/cohere.rs
+++ b/src/client/cohere.rs
@@ -1,9 +1,9 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
-    PromptType, SendData,
+    PromptType, ReplyHandler, SendData,
 };

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::{bail, Result};
 use async_trait::async_trait;
diff --git a/src/client/common.rs b/src/client/common.rs
index 9171e3a..2206d21 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -1,12 +1,9 @@
-use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model};
+use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model, ReplyHandler};

 use crate::{
     config::{GlobalConfig, Input},
-    render::ReplyHandler,
-    utils::{
-        init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, AbortSignal,
-        PromptKind,
-    },
+    render::{render_error, render_stream},
+    utils::{prompt_input_integer, prompt_input_string, tokenize, AbortSignal, PromptKind},
 };

 use anyhow::{Context, Result};
@@ -16,7 +13,7 @@ use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
 use serde::Deserialize;
 use serde_json::{json, Value};
 use std::{env, future::Future, time::Duration};
-use tokio::time::sleep;
+use tokio::{sync::mpsc::unbounded_channel, time::sleep};

 #[macro_export]
 macro_rules! register_client {
@@ -173,7 +170,7 @@ macro_rules! openai_compatible_client {
         async fn send_message_streaming_inner(
             &self,
             client: &reqwest::Client,
-            handler: &mut $crate::render::ReplyHandler,
+            handler: &mut $crate::client::ReplyHandler,
             data: $crate::client::SendData,
         ) -> Result<()> {
             let builder = self.request_builder(client, data)?;
@@ -201,7 +198,7 @@ macro_rules! config_get_fn {
 }

 #[async_trait]
-pub trait Client {
+pub trait Client: Sync + Send {
     fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);

     fn models(&self) -> Vec<Model>;
@@ -226,22 +223,24 @@ pub trait Client {
         Ok(client)
     }

-    fn send_message(&self, input: Input) -> Result<String> {
-        init_tokio_runtime()?.block_on(async {
-            let global_config = self.config().0;
-            if global_config.read().dry_run {
-                let content = global_config.read().echo_messages(&input);
-                return Ok(content);
-            }
-            let client = self.build_client()?;
-            let data = global_config.read().prepare_send_data(&input, false)?;
-            self.send_message_inner(&client, data)
-                .await
-                .with_context(|| "Failed to get answer")
-        })
+    async fn send_message(&self, input: Input) -> Result<String> {
+        let global_config = self.config().0;
+        if global_config.read().dry_run {
+            let content = global_config.read().echo_messages(&input);
+            return Ok(content);
+        }
+        let client = self.build_client()?;
+        let data = global_config.read().prepare_send_data(&input, false)?;
+        self.send_message_inner(&client, data)
+            .await
+            .with_context(|| "Failed to get answer")
     }

-    fn send_message_streaming(&self, input: &Input, handler: &mut ReplyHandler) -> Result<()> {
+    async fn send_message_streaming(
+        &self,
+        input: &Input,
+        handler: &mut ReplyHandler,
+    ) -> Result<()> {
         async fn watch_abort(abort: AbortSignal) {
             loop {
                 if abort.aborted() {
@@ -252,32 +251,30 @@ pub trait Client {
         }
         let abort = handler.get_abort();
         let input = input.clone();
-        init_tokio_runtime()?.block_on(async move {
-            tokio::select! {
-                ret = async {
-                    let global_config = self.config().0;
-                    if global_config.read().dry_run {
-                        let content = global_config.read().echo_messages(&input);
-                        let tokens = tokenize(&content);
-                        for token in tokens {
-                            tokio::time::sleep(Duration::from_millis(10)).await;
-                            handler.text(&token)?;
-                        }
-                        return Ok(());
+        tokio::select! {
+            ret = async {
+                let global_config = self.config().0;
+                if global_config.read().dry_run {
+                    let content = global_config.read().echo_messages(&input);
+                    let tokens = tokenize(&content);
+                    for token in tokens {
+                        tokio::time::sleep(Duration::from_millis(10)).await;
+                        handler.text(&token)?;
                     }
-                    let client = self.build_client()?;
-                    let data = global_config.read().prepare_send_data(&input, true)?;
-                    self.send_message_streaming_inner(&client, handler, data).await
-                } => {
-                    handler.done()?;
-                    ret.with_context(|| "Failed to get answer")
+                    return Ok(());
                 }
-                _ = watch_abort(abort.clone()) => {
-                    handler.done()?;
-                    Ok(())
-                },
+                let client = self.build_client()?;
+                let data = global_config.read().prepare_send_data(&input, true)?;
+                self.send_message_streaming_inner(&client, handler, data).await
+            } => {
+                handler.done()?;
+                ret.with_context(|| "Failed to get answer")
             }
-        })
+            _ = watch_abort(abort.clone()) => {
+                handler.done()?;
+                Ok(())
+            },
+        }
     }

     async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
@@ -336,6 +333,37 @@ pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value
     Ok((model, clients))
 }

+pub async fn send_stream(
+    input: &Input,
+    client: &dyn Client,
+    config: &GlobalConfig,
+    abort: AbortSignal,
+) -> Result<String> {
+    let (tx, rx) = unbounded_channel();
+    let mut stream_handler = ReplyHandler::new(tx, abort.clone());
+
+    let (send_ret, rend_ret) = tokio::join!(
+        client.send_message_streaming(input, &mut stream_handler),
+        render_stream(rx, config, abort.clone()),
+    );
+    if let Err(err) = rend_ret {
+        render_error(err, config.read().highlight);
+    }
+    let output = stream_handler.get_buffer().to_string();
+    match send_ret {
+        Ok(_) => {
+            println!();
+            Ok(output)
+        }
+        Err(err) => {
+            if !output.is_empty() {
+                println!();
+            }
+            Err(err)
+        }
+    }
+}
+
 #[allow(unused)]
 pub async fn send_message_as_streaming(
     builder: RequestBuilder,
diff --git a/src/client/ernie.rs b/src/client/ernie.rs
index 060f89b..db6a969 100644
--- a/src/client/ernie.rs
+++ b/src/client/ernie.rs
@@ -1,6 +1,9 @@
-use super::{patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, SendData};
+use super::{
+    patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, ReplyHandler,
+    SendData,
+};

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::{anyhow, bail, Context, Result};
 use async_trait::async_trait;
diff --git a/src/client/gemini.rs b/src/client/gemini.rs
index bcd10e3..0c60bee 100644
--- a/src/client/gemini.rs
+++ b/src/client/gemini.rs
@@ -1,7 +1,7 @@
 use super::vertexai::{build_body, send_message, send_message_streaming};
-use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData};
+use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, ReplyHandler, SendData};

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::Result;
 use async_trait::async_trait;
diff --git a/src/client/mod.rs b/src/client/mod.rs
index d49ed4e..bd85e74 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -2,10 +2,12 @@
 mod common;
 mod message;
 mod model;
+mod reply_handler;

 pub use common::*;
 pub use message::*;
 pub use model::*;
+pub use reply_handler::*;

 register_client!(
     (openai, "openai", OpenAIConfig, OpenAIClient),
diff --git a/src/client/ollama.rs b/src/client/ollama.rs
index 2c51f44..e652634 100644
--- a/src/client/ollama.rs
+++ b/src/client/ollama.rs
@@ -1,9 +1,9 @@
 use super::{
     message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
-    PromptType, SendData,
+    PromptType, ReplyHandler, SendData,
 };

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
@@ -118,7 +118,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
     while let Some(chunk) = stream.next().await {
         let chunk = chunk?;
         if chunk.is_empty() {
-             continue;
+            continue;
         }
         let data: Value = serde_json::from_slice(&chunk)?;
         if data["done"].is_boolean() {
diff --git a/src/client/openai.rs b/src/client/openai.rs
index 24c72cc..797ee98 100644
--- a/src/client/openai.rs
+++ b/src/client/openai.rs
@@ -1,6 +1,6 @@
-use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData};
+use super::{ExtraConfig, Model, OpenAIClient, PromptType, ReplyHandler, SendData};

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs
index 031abe7..2034736 100644
--- a/src/client/qianwen.rs
+++ b/src/client/qianwen.rs
@@ -1,11 +1,8 @@
 use super::{
-    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData,
+    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, ReplyHandler, SendData,
 };

-use crate::{
-    render::ReplyHandler,
-    utils::{sha256sum, PromptKind},
-};
+use crate::utils::{sha256sum, PromptKind};

 use anyhow::{anyhow, bail, Context, Result};
 use async_trait::async_trait;
diff --git a/src/client/reply_handler.rs b/src/client/reply_handler.rs
new file mode 100644
index 0000000..e11ea1d
--- /dev/null
+++ b/src/client/reply_handler.rs
@@ -0,0 +1,65 @@
+use crate::utils::AbortSignal;
+
+use anyhow::{Context, Result};
+use tokio::sync::mpsc::UnboundedSender;
+
+pub struct ReplyHandler {
+    sender: UnboundedSender<ReplyEvent>,
+    buffer: String,
+    abort: AbortSignal,
+}
+
+impl ReplyHandler {
+    pub fn new(sender: UnboundedSender<ReplyEvent>, abort: AbortSignal) -> Self {
+        Self {
+            sender,
+            abort,
+            buffer: String::new(),
+        }
+    }
+
+    pub fn text(&mut self, text: &str) -> Result<()> {
+        debug!("ReplyText: {}", text);
+        if text.is_empty() {
+            return Ok(());
+        }
+        self.buffer.push_str(text);
+        let ret = self
+            .sender
+            .send(ReplyEvent::Text(text.to_string()))
+            .with_context(|| "Failed to send ReplyEvent:Text");
+        self.safe_ret(ret)?;
+        Ok(())
+    }
+
+    pub fn done(&mut self) -> Result<()> {
+        debug!("ReplyDone");
+        let ret = self
+            .sender
+            .send(ReplyEvent::Done)
+            .with_context(|| "Failed to send ReplyEvent::Done");
+        self.safe_ret(ret)?;
+        Ok(())
+    }
+
+    pub fn get_buffer(&self) -> &str {
+        &self.buffer
+    }
+
+    pub fn get_abort(&self) -> AbortSignal {
+        self.abort.clone()
+    }
+
+    fn safe_ret(&self, ret: Result<()>) -> Result<()> {
+        if ret.is_err() && self.abort.aborted() {
+            return Ok(());
+        }
+        ret
+    }
+}
+
+#[derive(Debug)]
+pub enum ReplyEvent {
+    Text(String),
+    Done,
+}
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index ad39c16..88035ec 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -1,9 +1,9 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
-    SendData, VertexAIClient,
+    ReplyHandler, SendData, VertexAIClient,
 };

-use crate::{render::ReplyHandler, utils::PromptKind};
+use crate::utils::PromptKind;

 use anyhow::{anyhow, bail, Context, Result};
 use async_trait::async_trait;
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 9b771ec..f86cda1 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -41,6 +41,12 @@ const SESSIONS_DIR_NAME: &str = "sessions";

 const CLIENTS_FIELD: &str = "clients";

+const SUMMARIZE_PROMPT: &str =
+    "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.";
+const SUMMARY_PROMPT: &str = "This is a summary of the chat history as a recap: ";
+const LEFT_PROMPT: &str = "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ";
+const RIGHT_PROMPT: &str = "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}";
+
 #[derive(Debug, Clone, Deserialize)]
 #[serde(default)]
 pub struct Config {
@@ -59,10 +65,10 @@ pub struct Config {
     pub prelude: Option<String>,
     pub buffer_editor: Option<String>,
     pub compress_threshold: usize,
-    pub summarize_prompt: String,
-    pub summary_prompt: String,
-    pub left_prompt: String,
-    pub right_prompt: String,
+    pub summarize_prompt: Option<String>,
+    pub summary_prompt: Option<String>,
+    pub left_prompt: Option<String>,
+    pub right_prompt: Option<String>,
     pub clients: Vec<ClientConfig>,
     #[serde(skip)]
     pub roles: Vec<Role>,
@@ -95,12 +101,11 @@ impl Default for Config {
             prelude: None,
             buffer_editor: None,
             compress_threshold: 2000,
-            summarize_prompt: "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.".to_string(),
-            summary_prompt: "This is a summary of the chat history as a recap: ".into(),
-            left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
-            right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
-                .to_string(),
-            clients: vec![ClientConfig::default()],
+            summarize_prompt: None,
+            summary_prompt: None,
+            left_prompt: None,
+            right_prompt: None,
+            clients: vec![],
             roles: vec![],
             role: None,
             session: None,
@@ -677,10 +682,15 @@ impl Config {

     pub fn compress_session(&mut self, summary: &str) {
         if let Some(session) = self.session.as_mut() {
-            session.compress(format!("{}{}", self.summary_prompt, summary));
+            let summary_prompt = self.summary_prompt.as_deref().unwrap_or(SUMMARY_PROMPT);
+            session.compress(format!("{}{}", summary_prompt, summary));
         }
     }

+    pub fn summarize_prompt(&self) -> &str {
+        self.summarize_prompt.as_deref().unwrap_or(SUMMARIZE_PROMPT)
+    }
+
     pub fn is_compressing_session(&self) -> bool {
         self.session
             .as_ref()
@@ -728,12 +738,14 @@ impl Config {

     pub fn render_prompt_left(&self) -> String {
         let variables = self.generate_prompt_context();
-        render_prompt(&self.left_prompt, &variables)
+        let left_prompt = self.left_prompt.as_deref().unwrap_or(LEFT_PROMPT);
+        render_prompt(left_prompt, &variables)
     }

     pub fn render_prompt_right(&self) -> String {
         let variables = self.generate_prompt_context();
-        render_prompt(&self.right_prompt, &variables)
+        let right_prompt = self.right_prompt.as_deref().unwrap_or(RIGHT_PROMPT);
+        render_prompt(right_prompt, &variables)
     }

     pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
@@ -940,7 +952,7 @@ impl Config {
             }
         }

-        if let Some(ClientConfig::OpenAIConfig(client_config)) = self.clients.get_mut(0) {
+        if let Some(ClientConfig::OpenAIConfig(client_config)) = self.clients.first_mut() {
             if let Some(api_key) = value.get("api_key").and_then(|v| v.as_str()) {
                 client_config.api_key = Some(api_key.to_string())
             }
diff --git a/src/main.rs b/src/main.rs
index 6d5fe9b..bc6b8f1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,12 +10,13 @@ extern crate log;
 mod utils;

 use crate::cli::Cli;
-use crate::client::{ensure_model_capabilities, init_client, list_models};
+use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream};
 use crate::config::{Config, GlobalConfig, Input, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE};
-use crate::render::{render_error, render_stream, MarkdownRender};
+use crate::render::{render_error, MarkdownRender};
 use crate::repl::Repl;
 use crate::utils::{
-    cl100k_base_singleton, create_abort_signal, extract_block, run_command, CODE_BLOCK_RE,
+    cl100k_base_singleton, create_abort_signal, extract_block, run_command, run_spinner,
+    CODE_BLOCK_RE,
 };

 use anyhow::{bail, Result};
@@ -24,11 +25,12 @@ use inquire::{Select, Text};
 use is_terminal::IsTerminal;
 use parking_lot::RwLock;
 use std::io::{stderr, stdin, stdout, Read};
-use std::sync::{mpsc, Arc};
-use std::{process, thread};
-use utils::run_spinner;
+use std::process;
+use std::sync::Arc;
+use tokio::sync::oneshot;

-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     let cli = Cli::parse();
     let text = cli.text();
     let config = Arc::new(RwLock::new(Config::init(text.is_none())?));
@@ -91,7 +93,7 @@ fn main() -> Result<()> {
     if cli.execute {
         match input {
             Some(input) => {
-                execute(&config, input)?;
+                execute(&config, input).await?;
                 return Ok(());
             }
             None => bail!("No input text"),
@@ -99,8 +101,8 @@ fn main() -> Result<()> {
     }
     config.write().apply_prelude()?;
     if let Err(err) = match input {
-        Some(input) => start_directive(&config, input, cli.no_stream, cli.code),
-        None => start_interactive(&config),
+        Some(input) => start_directive(&config, input, cli.no_stream, cli.code).await,
+        None => start_interactive(&config).await,
     } {
         let highlight = stderr().is_terminal() && config.read().highlight;
         render_error(err, highlight)
@@ -108,7 +110,7 @@ fn main() -> Result<()> {
     Ok(())
 }

-fn start_directive(
+async fn start_directive(
     config: &GlobalConfig,
     input: Input,
     no_stream: bool,
@@ -120,7 +122,7 @@ fn start_directive(
     let is_terminal_stdout = stdout().is_terminal();
     let extract_code = !is_terminal_stdout && code_mode;
     let output = if no_stream || extract_code {
-        let output = client.send_message(input.clone())?;
+        let output = client.send_message(input.clone()).await?;
         let output = if extract_code && output.trim_start().starts_with("```") {
             extract_block(&output)
         } else {
@@ -136,7 +138,7 @@ fn start_directive(
         output
     } else {
         let abort = create_abort_signal();
-        render_stream(&input, client.as_ref(), config, abort)?
+        send_stream(&input, client.as_ref(), config, abort).await?
     };
     // Save the message/session
     config.write().save_message(input, &output)?;
@@ -144,19 +146,20 @@ fn start_directive(
     Ok(())
 }

-fn start_interactive(config: &GlobalConfig) -> Result<()> {
+async fn start_interactive(config: &GlobalConfig) -> Result<()> {
     cl100k_base_singleton();
     let mut repl: Repl = Repl::init(config)?;
-    repl.run()
+    repl.run().await
 }

-fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
+#[async_recursion::async_recursion]
+async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
     let client = init_client(config)?;
     config.read().maybe_print_send_tokens(&input);
-    let (tx, rx) = mpsc::sync_channel::<()>(0);
-    thread::spawn(move || run_spinner(" Generating", rx));
-    let ret = client.send_message(input.clone());
-    tx.send(())?;
+    let (spinner_tx, spinner_rx) = oneshot::channel();
+    tokio::spawn(run_spinner(" Generating", spinner_rx));
+    let ret = client.send_message(input.clone()).await;
+    let _ = spinner_tx.send(());
     let mut eval_str = ret?;
     if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
         eval_str = extract_block(&eval_str);
@@ -191,7 +194,7 @@ fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
                 }
                 let input = Input::from_str(&eval_str, config.read().input_context());
                 let abort = create_abort_signal();
-                render_stream(&input, client.as_ref(), config, abort)?;
+                send_stream(&input, client.as_ref(), config, abort).await?;
                 explain = true;
                 continue;
             }
@@ -202,7 +205,7 @@ fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
                     input.text()
                 );
                 input.set_text(text);
-                return execute(config, input);
+                return execute(config, input).await;
             }
             _ => {}
         }
diff --git a/src/render/mod.rs b/src/render/mod.rs
index dc4c081..146577b 100644
--- a/src/render/mod.rs
+++ b/src/render/mod.rs
@@ -4,61 +4,26 @@ mod stream;
 pub use self::markdown::{MarkdownRender, RenderOptions};
 use self::stream::{markdown_stream, raw_stream};

-use crate::client::Client;
-use crate::config::{GlobalConfig, Input};
 use crate::utils::AbortSignal;
+use crate::{client::ReplyEvent, config::GlobalConfig};

-use anyhow::{Context, Result};
-use crossbeam::channel::{unbounded, Sender};
-use crossbeam::sync::WaitGroup;
+use anyhow::Result;
 use is_terminal::IsTerminal;
 use nu_ansi_term::{Color, Style};
 use std::io::stdout;
-use std::thread::spawn;
+use tokio::sync::mpsc::UnboundedReceiver;

-pub fn render_stream(
-    input: &Input,
-    client: &dyn Client,
+pub async fn render_stream(
+    rx: UnboundedReceiver<ReplyEvent>,
     config: &GlobalConfig,
     abort: AbortSignal,
-) -> Result<String> {
-    let wg = WaitGroup::new();
-    let wg_cloned = wg.clone();
-    let render_options = config.read().get_render_options()?;
-    let mut stream_handler = {
-        let (tx, rx) = unbounded();
-        let abort_clone = abort.clone();
-        let highlight = config.read().highlight;
-        spawn(move || {
-            let run = move || {
-                if stdout().is_terminal() {
-                    let mut render = MarkdownRender::init(render_options)?;
-                    markdown_stream(&rx, &mut render, &abort)
-                } else {
-                    raw_stream(&rx, &abort)
-                }
-            };
-            if let Err(err) = run() {
-                render_error(err, highlight);
-            }
-            drop(wg_cloned);
-        });
-        ReplyHandler::new(tx, abort_clone)
-    };
-    let ret = client.send_message_streaming(input, &mut stream_handler);
-    wg.wait();
-    let output = stream_handler.get_buffer().to_string();
-    match ret {
-        Ok(_) => {
-            println!();
-            Ok(output)
-        }
-        Err(err) => {
-            if !output.is_empty() {
-                println!();
-            }
-            Err(err)
-        }
+) -> Result<()> {
+    if stdout().is_terminal() {
+        let render_options = config.read().get_render_options()?;
+        let mut render = MarkdownRender::init(render_options)?;
+        markdown_stream(rx, &mut render, &abort).await
+    } else {
+        raw_stream(rx, &abort).await
     }
 }
@@ -71,63 +36,3 @@ pub fn render_error(err: anyhow::Error, highlight: bool) {
         eprintln!("{err}");
     }
 }
-
-pub struct ReplyHandler {
-    sender: Sender<ReplyEvent>,
-    buffer: String,
-    abort: AbortSignal,
-}
-
-impl ReplyHandler {
-    pub fn new(sender: Sender<ReplyEvent>, abort: AbortSignal) -> Self {
-        Self {
-            sender,
-            abort,
-            buffer: String::new(),
-        }
-    }
-
-    pub fn text(&mut self, text: &str) -> Result<()> {
-        debug!("ReplyText: {}", text);
-        if text.is_empty() {
-            return Ok(());
-        }
-        self.buffer.push_str(text);
-        let ret = self
-            .sender
-            .send(ReplyEvent::Text(text.to_string()))
-            .with_context(|| "Failed to send ReplyEvent:Text");
-        self.safe_ret(ret)?;
-        Ok(())
-    }
-
-    pub fn done(&mut self) -> Result<()> {
-        debug!("ReplyDone");
-        let ret = self
-            .sender
-            .send(ReplyEvent::Done)
-            .with_context(|| "Failed to send ReplyEvent::Done");
-        self.safe_ret(ret)?;
-        Ok(())
-    }
-
-    pub fn get_buffer(&self) -> &str {
-        &self.buffer
-    }
-
-    pub fn get_abort(&self) -> AbortSignal {
-        self.abort.clone()
-    }
-
-    fn safe_ret(&self, ret: Result<()>) -> Result<()> {
-        if ret.is_err() && self.abort.aborted() {
-            return Ok(());
-        }
-        ret
-    }
-}
-
-pub enum ReplyEvent {
-    Text(String),
-    Done,
-}
diff --git a/src/render/stream.rs b/src/render/stream.rs
index 9fdad5a..6007690 100644
--- a/src/render/stream.rs
+++ b/src/render/stream.rs
@@ -1,9 +1,8 @@
 use super::{MarkdownRender, ReplyEvent};

-use crate::utils::{AbortSignal, Spinner};
+use crate::utils::{run_spinner, AbortSignal};

 use anyhow::Result;
-use crossbeam::channel::Receiver;
 use crossterm::{
     cursor,
     event::{self, Event, KeyCode, KeyModifiers},
@@ -12,32 +11,32 @@ use crossterm::{
 };
 use std::{
     io::{self, stdout, Stdout, Write},
-    ops::Div,
-    time::{Duration, Instant},
+    time::Duration,
 };
 use textwrap::core::display_width;
+use tokio::sync::{mpsc::UnboundedReceiver, oneshot};

-pub fn markdown_stream(
-    rx: &Receiver<ReplyEvent>,
+pub async fn markdown_stream(
+    rx: UnboundedReceiver<ReplyEvent>,
     render: &mut MarkdownRender,
     abort: &AbortSignal,
 ) -> Result<()> {
     enable_raw_mode()?;

     let mut stdout = io::stdout();
-    let ret = markdown_stream_inner(rx, render, abort, &mut stdout);
+    let ret = markdown_stream_inner(rx, render, abort, &mut stdout).await;

     disable_raw_mode()?;

     ret
 }

-pub fn raw_stream(rx: &Receiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> {
+pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> {
     loop {
         if abort.aborted() {
             return Ok(());
         }
-        if let Ok(evt) = rx.try_recv() {
+        if let Some(evt) = rx.recv().await {
             match evt {
                 ReplyEvent::Text(text) => {
                     print!("{}", text);
@@ -52,30 +51,29 @@ pub fn raw_stream(rx: &Receiver<ReplyEvent>, abort: &AbortSignal) -> Result<()>
     Ok(())
 }

-fn markdown_stream_inner(
-    rx: &Receiver<ReplyEvent>,
+async fn markdown_stream_inner(
+    mut rx: UnboundedReceiver<ReplyEvent>,
     render: &mut MarkdownRender,
     abort: &AbortSignal,
     writer: &mut Stdout,
 ) -> Result<()> {
-    let mut last_tick = Instant::now();
-    let tick_rate = Duration::from_millis(50);
-
     let mut buffer = String::new();
     let mut buffer_rows = 1;

     let columns = terminal::size()?.0;

-    let mut spinner = Spinner::new(" Generating");
+    let (spinner_tx, spinner_rx) = oneshot::channel();
+    let mut spinner_tx = Some(spinner_tx);
+    tokio::spawn(run_spinner(" Generating", spinner_rx));

     'outer: loop {
         if abort.aborted() {
             return Ok(());
         }
-        spinner.step(writer)?;
-
-        for reply_event in gather_events(rx) {
-            spinner.stop(writer)?;
+        for reply_event in gather_events(&mut rx).await {
+            if let Some(spinner_tx) = spinner_tx.take() {
+                let _ = spinner_tx.send(());
+            }

             match reply_event {
                 ReplyEvent::Text(mut text) => {
@@ -135,10 +133,7 @@ fn markdown_stream_inner(
             }
         }

-        let timeout = tick_rate
-            .checked_sub(last_tick.elapsed())
-            .unwrap_or_else(|| tick_rate.div(2));
-        if crossterm::event::poll(timeout)? {
+        if crossterm::event::poll(Duration::from_millis(25))? {
             if let Event::Key(key) = event::read()? {
                 match key.code {
                     KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
@@ -153,28 +148,31 @@ fn markdown_stream_inner(
                 }
             }
         }
-
-        if last_tick.elapsed() >= tick_rate {
-            last_tick = Instant::now();
-        }
     }

-    spinner.stop(writer)?;
-
+    if let Some(spinner_tx) = spinner_tx.take() {
+        let _ = spinner_tx.send(());
+    }
     Ok(())
 }

-fn gather_events(rx: &Receiver<ReplyEvent>) -> Vec<ReplyEvent> {
+async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent> {
     let mut texts = vec![];
     let mut done = false;
-    for reply_event in rx.try_iter() {
-        match reply_event {
-            ReplyEvent::Text(v) => texts.push(v),
-            ReplyEvent::Done => {
-                done = true;
+    tokio::select! {
+        _ = async {
+            while let Some(reply_event) = rx.recv().await {
+                match reply_event {
+                    ReplyEvent::Text(v) => texts.push(v),
+                    ReplyEvent::Done => {
+                        done = true;
+                        break;
+                    }
+                }
             }
-        }
-    }
+        } => {}
+        _ = tokio::time::sleep(Duration::from_millis(50)) => {}
+    };
     let mut events = vec![];
     if !texts.is_empty() {
         events.push(ReplyEvent::Text(texts.join("")))
diff --git a/src/repl/mod.rs b/src/repl/mod.rs
index 412d0be..3c0b98b 100644
--- a/src/repl/mod.rs
+++ b/src/repl/mod.rs
@@ -6,9 +6,9 @@ use self::completer::ReplCompleter;
 use self::highlighter::ReplHighlighter;
 use self::prompt::ReplPrompt;

-use crate::client::{ensure_model_capabilities, init_client};
+use crate::client::{ensure_model_capabilities, init_client, send_stream};
 use crate::config::{GlobalConfig, Input, InputContext, State};
-use crate::render::{render_error, render_stream};
+use crate::render::render_error;
 use crate::utils::{create_abort_signal, set_text, AbortSignal};

 use anyhow::{bail, Context, Result};
@@ -93,7 +93,7 @@ impl Repl {
         })
     }

-    pub fn run(&mut self) -> Result<()> {
+    pub async fn run(&mut self) -> Result<()> {
         self.banner();

         loop {
@@ -104,7 +104,7 @@ impl Repl {
             match sig {
                 Ok(Signal::Success(line)) => {
                     self.abort.reset();
-                    match self.handle(&line) {
+                    match self.handle(&line).await {
                         Ok(exit) => {
                             if exit {
                                 break;
@@ -127,11 +127,11 @@ impl Repl {
                 _ => {}
             }
         }
-        self.handle(".exit session")?;
+        self.handle(".exit session").await?;
         Ok(())
     }

-    fn handle(&self, mut line: &str) -> Result<bool> {
+    async fn handle(&self, mut line: &str) -> Result<bool> {
         if let Ok(Some(captures)) = MULTILINE_RE.captures(line) {
             if let Some(text_match) = captures.get(1) {
                 line = text_match.as_str();
@@ -175,7 +175,7 @@ impl Repl {
                     let role = self.config.read().retrieve_role(name.trim())?;
                     let input = Input::from_str(text.trim(), InputContext::new(Some(role), false));
-                    self.ask(input)?;
+                    self.ask(input).await?;
                 }
                 None => {
                     self.config.write().set_role(args)?;
@@ -220,7 +220,7 @@ impl Repl {
                     };
                     let files = shell_words::split(files).with_context(|| "Invalid args")?;
                     let input = Input::new(text, files, self.config.read().input_context())?;
-                    self.ask(input)?;
+                    self.ask(input).await?;
                 }
                 None => println!("Usage: .file ... [-- ...]"),
             },
@@ -246,7 +246,7 @@ impl Repl {
             },
             None => {
                 let input = Input::from_str(line, self.config.read().input_context());
-                self.ask(input)?;
+                self.ask(input).await?;
             }
         }

@@ -255,7 +255,7 @@ impl Repl {
         Ok(false)
     }

-    fn ask(&self, input: Input) -> Result<()> {
+    async fn ask(&self, input: Input) -> Result<()> {
         if input.is_empty() {
             return Ok(());
         }
@@ -265,7 +265,7 @@ impl Repl {
         self.config.read().maybe_print_send_tokens(&input);
         let mut client = init_client(&self.config)?;
         ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
-        let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
+        let output = send_stream(&input, client.as_ref(), &self.config, self.abort.clone()).await?;
         self.config.write().save_message(input, &output)?;
         self.config.read().maybe_copy(&output);
         if self.config.write().should_compress_session() {
@@ -283,10 +282,9 @@ impl Repl {
                 color.italic().paint("compress_threshold"),
                 color.normal().paint("`."),
             );
-            std::thread::spawn(move || -> anyhow::Result<()> {
-                let _ = compress_session(&config);
+            tokio::spawn(async move {
+                let _ = compress_session(&config).await;
                 config.write().end_compressing_session();
-                Ok(())
             });
         }
         Ok(())
@@ -443,14 +442,14 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
     }
 }

-fn compress_session(config: &GlobalConfig) -> Result<()> {
+async fn compress_session(config: &GlobalConfig) -> Result<()> {
     let input = Input::from_str(
-        &config.read().summarize_prompt,
+        config.read().summarize_prompt(),
         config.read().input_context(),
     );
     let mut client = init_client(config)?;
     ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
-    let summary = client.send_message(input)?;
+    let summary = client.send_message(input).await?;
     config.write().compress_session(&summary);
     Ok(())
 }
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index ba438a7..38cf766 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -9,7 +9,7 @@ pub use self::abort_signal::{create_abort_signal, AbortSignal};
 pub use self::clipboard::set_text;
 pub use self::prompt_input::*;
 pub use self::render_prompt::render_prompt;
-pub use self::spinner::{run_spinner, Spinner};
+pub use self::spinner::run_spinner;
 pub use self::tiktoken::cl100k_base_singleton;

 use fancy_regex::Regex;
@@ -82,14 +82,6 @@ pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option<bool> {
     Some(light)
 }

-pub fn init_tokio_runtime() -> anyhow::Result<tokio::runtime::Runtime> {
-    use anyhow::Context;
-    tokio::runtime::Builder::new_current_thread()
-        .enable_all()
-        .build()
-        .with_context(|| "Failed to init tokio")
-}
-
 pub fn sha256sum(input: &str) -> String {
     let mut hasher = Sha256::new();
     hasher.update(input);
diff --git a/src/utils/spinner.rs b/src/utils/spinner.rs
index 3715eff..0dd4d01 100644
--- a/src/utils/spinner.rs
+++ b/src/utils/spinner.rs
@@ -2,10 +2,9 @@ use anyhow::Result;
 use crossterm::{cursor, queue, style, terminal};
 use std::{
     io::{stdout, Stdout, Write},
-    sync::mpsc,
-    thread,
     time::Duration,
 };
+use tokio::{sync::oneshot, time::interval};

 pub struct Spinner {
     index: usize,
@@ -56,17 +55,20 @@ impl Spinner {
     }
 }

-pub fn run_spinner(message: &str, rx: mpsc::Receiver<()>) -> Result<()> {
+pub async fn run_spinner(message: &str, rx: oneshot::Receiver<()>) -> Result<()> {
     let mut writer = stdout();
     let mut spinner = Spinner::new(message);
-    loop {
-        spinner.step(&mut writer)?;
-        if let Ok(()) = rx.try_recv() {
+    let mut interval = interval(Duration::from_millis(50));
+    tokio::select! {
+        _ = async {
+            loop {
+                interval.tick().await;
+                let _ = spinner.step(&mut writer);
+            }
+        } => {}
+        _ = rx => {
             spinner.stop(&mut writer)?;
-            break;
         }
-        thread::sleep(Duration::from_millis(50))
     }
-    spinner.stop(&mut writer)?;
     Ok(())
 }