refactor: more async code (#427)

pull/428/head
sigoden 2 months ago committed by GitHub
parent c5506fe393
commit 1cc89eff51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

80
Cargo.lock generated

@ -33,6 +33,7 @@ dependencies = [
"ansi_colours",
"anyhow",
"arboard",
"async-recursion",
"async-trait",
"base64 0.22.0",
"bincode",
@ -41,7 +42,6 @@ dependencies = [
"bytes",
"chrono",
"clap",
"crossbeam",
"crossterm 0.27.0",
"dirs",
"fancy-regex",
@ -164,6 +164,17 @@ dependencies = [
"x11rb",
]
[[package]]
name = "async-recursion"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30c5ef0ede93efbf733c1a727f3b6b5a1060bbedd5600183e66f6e4be4af0ec5"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.79"
@ -419,62 +430,6 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
[[package]]
name = "crossterm"
version = "0.25.0"
@ -1234,6 +1189,16 @@ dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "num_threads"
version = "0.1.7"
@ -2092,6 +2057,7 @@ dependencies = [
"bytes",
"libc",
"mio",
"num_cpus",
"pin-project-lite",
"signal-hook-registry",
"socket2",

@ -22,8 +22,7 @@ reedline = "0.31.0"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = { version = "1.0.93", features = ["preserve_order"] }
serde_yaml = "0.9.17"
tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal"] }
crossbeam = "0.8.2"
tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal", "rt-multi-thread"] }
crossterm = "0.27.0"
chrono = "0.4.23"
bincode = "1.3.3"
@ -45,6 +44,7 @@ mime_guess = "2.0.4"
sha2 = "0.10.8"
bitflags = "2.4.1"
unicode-width = "0.1.11"
async-recursion = "1.1.0"
[dependencies.reqwest]
version = "0.12.0"

@ -1,11 +1,10 @@
use super::{patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData};
use crate::{
client::{ImageUrl, MessageContent, MessageContentPart},
render::ReplyHandler,
utils::PromptKind,
use super::{
patch_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
MessageContentPart, Model, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;

@ -1,9 +1,9 @@
use super::{
json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
PromptType, SendData,
PromptType, ReplyHandler, SendData,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::{bail, Result};
use async_trait::async_trait;

@ -1,12 +1,9 @@
use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model};
use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model, ReplyHandler};
use crate::{
config::{GlobalConfig, Input},
render::ReplyHandler,
utils::{
init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, AbortSignal,
PromptKind,
},
render::{render_error, render_stream},
utils::{prompt_input_integer, prompt_input_string, tokenize, AbortSignal, PromptKind},
};
use anyhow::{Context, Result};
@ -16,7 +13,7 @@ use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{env, future::Future, time::Duration};
use tokio::time::sleep;
use tokio::{sync::mpsc::unbounded_channel, time::sleep};
#[macro_export]
macro_rules! register_client {
@ -173,7 +170,7 @@ macro_rules! openai_compatible_client {
async fn send_message_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::render::ReplyHandler,
handler: &mut $crate::client::ReplyHandler,
data: $crate::client::SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
@ -201,7 +198,7 @@ macro_rules! config_get_fn {
}
#[async_trait]
pub trait Client {
pub trait Client: Sync + Send {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
fn models(&self) -> Vec<Model>;
@ -226,22 +223,24 @@ pub trait Client {
Ok(client)
}
fn send_message(&self, input: Input) -> Result<String> {
init_tokio_runtime()?.block_on(async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
return Ok(content);
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
})
async fn send_message(&self, input: Input) -> Result<String> {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
return Ok(content);
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
}
fn send_message_streaming(&self, input: &Input, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(
&self,
input: &Input,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn watch_abort(abort: AbortSignal) {
loop {
if abort.aborted() {
@ -252,32 +251,30 @@ pub trait Client {
}
let abort = handler.get_abort();
let input = input.clone();
init_tokio_runtime()?.block_on(async move {
tokio::select! {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(10)).await;
handler.text(&token)?;
}
return Ok(());
tokio::select! {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(10)).await;
handler.text(&token)?;
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to get answer")
return Ok(());
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to get answer")
}
})
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
}
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
@ -336,6 +333,37 @@ pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value
Ok((model, clients))
}
/// Requests a streamed reply for `input` from `client` and renders it as it arrives.
///
/// The request and the renderer are driven concurrently with `tokio::join!`: the client
/// pushes `ReplyEvent`s through an unbounded channel via a `ReplyHandler`, and
/// `render_stream` consumes them from the receiving end.
///
/// A rendering error is reported through `render_error` and does not by itself fail the
/// call. A sending error is returned to the caller.
///
/// Returns the text accumulated by the handler when sending succeeded.
pub async fn send_stream(
    input: &Input,
    client: &dyn Client,
    config: &GlobalConfig,
    abort: AbortSignal,
) -> Result<String> {
    let (event_tx, event_rx) = unbounded_channel();
    let mut handler = ReplyHandler::new(event_tx, abort.clone());

    let (sent, rendered) = tokio::join!(
        client.send_message_streaming(input, &mut handler),
        render_stream(event_rx, config, abort.clone()),
    );

    if let Err(err) = rendered {
        render_error(err, config.read().highlight);
    }

    let output = handler.get_buffer().to_string();
    // Terminate the line after the streamed text; on failure only when something was
    // actually produced.
    if sent.is_ok() || !output.is_empty() {
        println!();
    }
    sent.map(|_| output)
}
#[allow(unused)]
pub async fn send_message_as_streaming<F, Fut>(
builder: RequestBuilder,

@ -1,6 +1,9 @@
use super::{patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, SendData};
use super::{
patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, ReplyHandler,
SendData,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;

@ -1,7 +1,7 @@
use super::vertexai::{build_body, send_message, send_message_streaming};
use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData};
use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, ReplyHandler, SendData};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::Result;
use async_trait::async_trait;

@ -2,10 +2,12 @@
mod common;
mod message;
mod model;
mod reply_handler;
pub use common::*;
pub use message::*;
pub use model::*;
pub use reply_handler::*;
register_client!(
(openai, "openai", OpenAIConfig, OpenAIClient),

@ -1,9 +1,9 @@
use super::{
message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, SendData,
PromptType, ReplyHandler, SendData,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@ -118,7 +118,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
if chunk.is_empty() {
continue;
continue;
}
let data: Value = serde_json::from_slice(&chunk)?;
if data["done"].is_boolean() {

@ -1,6 +1,6 @@
use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData};
use super::{ExtraConfig, Model, OpenAIClient, PromptType, ReplyHandler, SendData};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;

@ -1,11 +1,8 @@
use super::{
message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData,
message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, ReplyHandler, SendData,
};
use crate::{
render::ReplyHandler,
utils::{sha256sum, PromptKind},
};
use crate::utils::{sha256sum, PromptKind};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;

@ -0,0 +1,65 @@
use crate::utils::AbortSignal;
use anyhow::{Context, Result};
use tokio::sync::mpsc::UnboundedSender;
/// Forwards streamed reply text to a channel while keeping a copy of everything sent.
pub struct ReplyHandler {
// Sending half of the channel on which `ReplyEvent`s are emitted.
sender: UnboundedSender<ReplyEvent>,
// Concatenation of every non-empty chunk passed to `text`.
buffer: String,
// Abort flag; when it reports aborted, send failures are swallowed by `safe_ret`.
abort: AbortSignal,
}
impl ReplyHandler {
    /// Creates a handler that emits events on `sender` and starts with an empty buffer.
    pub fn new(sender: UnboundedSender<ReplyEvent>, abort: AbortSignal) -> Self {
        let buffer = String::new();
        Self {
            sender,
            buffer,
            abort,
        }
    }

    /// Appends `text` to the buffer and emits it as a `ReplyEvent::Text`.
    ///
    /// Empty text is logged but neither buffered nor sent.
    pub fn text(&mut self, text: &str) -> Result<()> {
        debug!("ReplyText: {}", text);
        if !text.is_empty() {
            self.buffer.push_str(text);
            let sent = self
                .sender
                .send(ReplyEvent::Text(text.to_owned()))
                .context("Failed to send ReplyEvent:Text");
            self.safe_ret(sent)?;
        }
        Ok(())
    }

    /// Emits `ReplyEvent::Done` on the channel.
    pub fn done(&mut self) -> Result<()> {
        debug!("ReplyDone");
        let sent = self
            .sender
            .send(ReplyEvent::Done)
            .context("Failed to send ReplyEvent::Done");
        self.safe_ret(sent)
    }

    /// Returns all text buffered so far.
    pub fn get_buffer(&self) -> &str {
        self.buffer.as_str()
    }

    /// Returns a clone of the abort signal held by this handler.
    pub fn get_abort(&self) -> AbortSignal {
        self.abort.clone()
    }

    /// Turns an error into `Ok(())` when the abort signal reports aborted; otherwise
    /// passes `ret` through unchanged.
    fn safe_ret(&self, ret: Result<()>) -> Result<()> {
        match ret {
            // The abort flag is only consulted when there is an error to suppress.
            Err(_) if self.abort.aborted() => Ok(()),
            other => other,
        }
    }
}
/// Event emitted by a `ReplyHandler` while a reply is being streamed.
#[derive(Debug)]
pub enum ReplyEvent {
// A chunk of reply text (sent by `ReplyHandler::text`).
Text(String),
// Sent by `ReplyHandler::done`.
Done,
}

@ -1,9 +1,9 @@
use super::{
json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
SendData, VertexAIClient,
ReplyHandler, SendData, VertexAIClient,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;

@ -41,6 +41,12 @@ const SESSIONS_DIR_NAME: &str = "sessions";
const CLIENTS_FIELD: &str = "clients";
const SUMMARIZE_PROMPT: &str =
"Summarize the discussion briefly in 200 words or less to use as a prompt for future context.";
const SUMMARY_PROMPT: &str = "This is a summary of the chat history as a recap: ";
const LEFT_PROMPT: &str = "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ";
const RIGHT_PROMPT: &str = "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}";
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
@ -59,10 +65,10 @@ pub struct Config {
pub prelude: Option<String>,
pub buffer_editor: Option<String>,
pub compress_threshold: usize,
pub summarize_prompt: String,
pub summary_prompt: String,
pub left_prompt: String,
pub right_prompt: String,
pub summarize_prompt: Option<String>,
pub summary_prompt: Option<String>,
pub left_prompt: Option<String>,
pub right_prompt: Option<String>,
pub clients: Vec<ClientConfig>,
#[serde(skip)]
pub roles: Vec<Role>,
@ -95,12 +101,11 @@ impl Default for Config {
prelude: None,
buffer_editor: None,
compress_threshold: 2000,
summarize_prompt: "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.".to_string(),
summary_prompt: "This is a summary of the chat history as a recap: ".into(),
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
.to_string(),
clients: vec![ClientConfig::default()],
summarize_prompt: None,
summary_prompt: None,
left_prompt: None,
right_prompt: None,
clients: vec![],
roles: vec![],
role: None,
session: None,
@ -677,10 +682,15 @@ impl Config {
pub fn compress_session(&mut self, summary: &str) {
if let Some(session) = self.session.as_mut() {
session.compress(format!("{}{}", self.summary_prompt, summary));
let summary_prompt = self.summary_prompt.as_deref().unwrap_or(SUMMARY_PROMPT);
session.compress(format!("{}{}", summary_prompt, summary));
}
}
/// Returns the configured summarize prompt, falling back to `SUMMARIZE_PROMPT` when the
/// field is unset.
pub fn summarize_prompt(&self) -> &str {
    match self.summarize_prompt.as_deref() {
        Some(prompt) => prompt,
        None => SUMMARIZE_PROMPT,
    }
}
pub fn is_compressing_session(&self) -> bool {
self.session
.as_ref()
@ -728,12 +738,14 @@ impl Config {
pub fn render_prompt_left(&self) -> String {
let variables = self.generate_prompt_context();
render_prompt(&self.left_prompt, &variables)
let left_prompt = self.left_prompt.as_deref().unwrap_or(LEFT_PROMPT);
render_prompt(left_prompt, &variables)
}
pub fn render_prompt_right(&self) -> String {
let variables = self.generate_prompt_context();
render_prompt(&self.right_prompt, &variables)
let right_prompt = self.right_prompt.as_deref().unwrap_or(RIGHT_PROMPT);
render_prompt(right_prompt, &variables)
}
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
@ -940,7 +952,7 @@ impl Config {
}
}
if let Some(ClientConfig::OpenAIConfig(client_config)) = self.clients.get_mut(0) {
if let Some(ClientConfig::OpenAIConfig(client_config)) = self.clients.first_mut() {
if let Some(api_key) = value.get("api_key").and_then(|v| v.as_str()) {
client_config.api_key = Some(api_key.to_string())
}

@ -10,12 +10,13 @@ extern crate log;
mod utils;
use crate::cli::Cli;
use crate::client::{ensure_model_capabilities, init_client, list_models};
use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream};
use crate::config::{Config, GlobalConfig, Input, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE};
use crate::render::{render_error, render_stream, MarkdownRender};
use crate::render::{render_error, MarkdownRender};
use crate::repl::Repl;
use crate::utils::{
cl100k_base_singleton, create_abort_signal, extract_block, run_command, CODE_BLOCK_RE,
cl100k_base_singleton, create_abort_signal, extract_block, run_command, run_spinner,
CODE_BLOCK_RE,
};
use anyhow::{bail, Result};
@ -24,11 +25,12 @@ use inquire::{Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use std::io::{stderr, stdin, stdout, Read};
use std::sync::{mpsc, Arc};
use std::{process, thread};
use utils::run_spinner;
use std::process;
use std::sync::Arc;
use tokio::sync::oneshot;
fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
let text = cli.text();
let config = Arc::new(RwLock::new(Config::init(text.is_none())?));
@ -91,7 +93,7 @@ fn main() -> Result<()> {
if cli.execute {
match input {
Some(input) => {
execute(&config, input)?;
execute(&config, input).await?;
return Ok(());
}
None => bail!("No input text"),
@ -99,8 +101,8 @@ fn main() -> Result<()> {
}
config.write().apply_prelude()?;
if let Err(err) = match input {
Some(input) => start_directive(&config, input, cli.no_stream, cli.code),
None => start_interactive(&config),
Some(input) => start_directive(&config, input, cli.no_stream, cli.code).await,
None => start_interactive(&config).await,
} {
let highlight = stderr().is_terminal() && config.read().highlight;
render_error(err, highlight)
@ -108,7 +110,7 @@ fn main() -> Result<()> {
Ok(())
}
fn start_directive(
async fn start_directive(
config: &GlobalConfig,
input: Input,
no_stream: bool,
@ -120,7 +122,7 @@ fn start_directive(
let is_terminal_stdout = stdout().is_terminal();
let extract_code = !is_terminal_stdout && code_mode;
let output = if no_stream || extract_code {
let output = client.send_message(input.clone())?;
let output = client.send_message(input.clone()).await?;
let output = if extract_code && output.trim_start().starts_with("```") {
extract_block(&output)
} else {
@ -136,7 +138,7 @@ fn start_directive(
output
} else {
let abort = create_abort_signal();
render_stream(&input, client.as_ref(), config, abort)?
send_stream(&input, client.as_ref(), config, abort).await?
};
// Save the message/session
config.write().save_message(input, &output)?;
@ -144,19 +146,20 @@ fn start_directive(
Ok(())
}
fn start_interactive(config: &GlobalConfig) -> Result<()> {
async fn start_interactive(config: &GlobalConfig) -> Result<()> {
cl100k_base_singleton();
let mut repl: Repl = Repl::init(config)?;
repl.run()
repl.run().await
}
fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
#[async_recursion::async_recursion]
async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
let client = init_client(config)?;
config.read().maybe_print_send_tokens(&input);
let (tx, rx) = mpsc::sync_channel::<()>(0);
thread::spawn(move || run_spinner(" Generating", rx));
let ret = client.send_message(input.clone());
tx.send(())?;
let (spinner_tx, spinner_rx) = oneshot::channel();
tokio::spawn(run_spinner(" Generating", spinner_rx));
let ret = client.send_message(input.clone()).await;
let _ = spinner_tx.send(());
let mut eval_str = ret?;
if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
eval_str = extract_block(&eval_str);
@ -191,7 +194,7 @@ fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
}
let input = Input::from_str(&eval_str, config.read().input_context());
let abort = create_abort_signal();
render_stream(&input, client.as_ref(), config, abort)?;
send_stream(&input, client.as_ref(), config, abort).await?;
explain = true;
continue;
}
@ -202,7 +205,7 @@ fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
input.text()
);
input.set_text(text);
return execute(config, input);
return execute(config, input).await;
}
_ => {}
}

@ -4,61 +4,26 @@ mod stream;
pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream};
use crate::client::Client;
use crate::config::{GlobalConfig, Input};
use crate::utils::AbortSignal;
use crate::{client::ReplyEvent, config::GlobalConfig};
use anyhow::{Context, Result};
use crossbeam::channel::{unbounded, Sender};
use crossbeam::sync::WaitGroup;
use anyhow::Result;
use is_terminal::IsTerminal;
use nu_ansi_term::{Color, Style};
use std::io::stdout;
use std::thread::spawn;
use tokio::sync::mpsc::UnboundedReceiver;
pub fn render_stream(
input: &Input,
client: &dyn Client,
pub async fn render_stream(
rx: UnboundedReceiver<ReplyEvent>,
config: &GlobalConfig,
abort: AbortSignal,
) -> Result<String> {
let wg = WaitGroup::new();
let wg_cloned = wg.clone();
let render_options = config.read().get_render_options()?;
let mut stream_handler = {
let (tx, rx) = unbounded();
let abort_clone = abort.clone();
let highlight = config.read().highlight;
spawn(move || {
let run = move || {
if stdout().is_terminal() {
let mut render = MarkdownRender::init(render_options)?;
markdown_stream(&rx, &mut render, &abort)
} else {
raw_stream(&rx, &abort)
}
};
if let Err(err) = run() {
render_error(err, highlight);
}
drop(wg_cloned);
});
ReplyHandler::new(tx, abort_clone)
};
let ret = client.send_message_streaming(input, &mut stream_handler);
wg.wait();
let output = stream_handler.get_buffer().to_string();
match ret {
Ok(_) => {
println!();
Ok(output)
}
Err(err) => {
if !output.is_empty() {
println!();
}
Err(err)
}
) -> Result<()> {
if stdout().is_terminal() {
let render_options = config.read().get_render_options()?;
let mut render = MarkdownRender::init(render_options)?;
markdown_stream(rx, &mut render, &abort).await
} else {
raw_stream(rx, &abort).await
}
}
@ -71,63 +36,3 @@ pub fn render_error(err: anyhow::Error, highlight: bool) {
eprintln!("{err}");
}
}
/// Forwards streamed reply text to a channel while keeping a copy of everything sent.
pub struct ReplyHandler {
// Sending half of the channel on which `ReplyEvent`s are emitted.
sender: Sender<ReplyEvent>,
// Concatenation of every non-empty chunk passed to `text`.
buffer: String,
// Abort flag; when it reports aborted, send failures are swallowed by `safe_ret`.
abort: AbortSignal,
}
impl ReplyHandler {
    /// Creates a handler that emits events on `sender` and starts with an empty buffer.
    pub fn new(sender: Sender<ReplyEvent>, abort: AbortSignal) -> Self {
        let buffer = String::new();
        Self {
            sender,
            buffer,
            abort,
        }
    }

    /// Appends `text` to the buffer and emits it as a `ReplyEvent::Text`.
    ///
    /// Empty text is logged but neither buffered nor sent.
    pub fn text(&mut self, text: &str) -> Result<()> {
        debug!("ReplyText: {}", text);
        if !text.is_empty() {
            self.buffer.push_str(text);
            let sent = self
                .sender
                .send(ReplyEvent::Text(text.to_owned()))
                .context("Failed to send ReplyEvent:Text");
            self.safe_ret(sent)?;
        }
        Ok(())
    }

    /// Emits `ReplyEvent::Done` on the channel.
    pub fn done(&mut self) -> Result<()> {
        debug!("ReplyDone");
        let sent = self
            .sender
            .send(ReplyEvent::Done)
            .context("Failed to send ReplyEvent::Done");
        self.safe_ret(sent)
    }

    /// Returns all text buffered so far.
    pub fn get_buffer(&self) -> &str {
        self.buffer.as_str()
    }

    /// Returns a clone of the abort signal held by this handler.
    pub fn get_abort(&self) -> AbortSignal {
        self.abort.clone()
    }

    /// Turns an error into `Ok(())` when the abort signal reports aborted; otherwise
    /// passes `ret` through unchanged.
    fn safe_ret(&self, ret: Result<()>) -> Result<()> {
        match ret {
            // The abort flag is only consulted when there is an error to suppress.
            Err(_) if self.abort.aborted() => Ok(()),
            other => other,
        }
    }
}
/// Event emitted by a `ReplyHandler` while a reply is being streamed.
pub enum ReplyEvent {
// A chunk of reply text (sent by `ReplyHandler::text`).
Text(String),
// Sent by `ReplyHandler::done`.
Done,
}

@ -1,9 +1,8 @@
use super::{MarkdownRender, ReplyEvent};
use crate::utils::{AbortSignal, Spinner};
use crate::utils::{run_spinner, AbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyModifiers},
@ -12,32 +11,32 @@ use crossterm::{
};
use std::{
io::{self, stdout, Stdout, Write},
ops::Div,
time::{Duration, Instant},
time::Duration,
};
use textwrap::core::display_width;
use tokio::sync::{mpsc::UnboundedReceiver, oneshot};
pub fn markdown_stream(
rx: &Receiver<ReplyEvent>,
pub async fn markdown_stream(
rx: UnboundedReceiver<ReplyEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
let ret = markdown_stream_inner(rx, render, abort, &mut stdout);
let ret = markdown_stream_inner(rx, render, abort, &mut stdout).await;
disable_raw_mode()?;
ret
}
pub fn raw_stream(rx: &Receiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> {
pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> {
loop {
if abort.aborted() {
return Ok(());
}
if let Ok(evt) = rx.try_recv() {
if let Some(evt) = rx.recv().await {
match evt {
ReplyEvent::Text(text) => {
print!("{}", text);
@ -52,30 +51,29 @@ pub fn raw_stream(rx: &Receiver<ReplyEvent>, abort: &AbortSignal) -> Result<()>
Ok(())
}
fn markdown_stream_inner(
rx: &Receiver<ReplyEvent>,
async fn markdown_stream_inner(
mut rx: UnboundedReceiver<ReplyEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
let tick_rate = Duration::from_millis(50);
let mut buffer = String::new();
let mut buffer_rows = 1;
let columns = terminal::size()?.0;
let mut spinner = Spinner::new(" Generating");
let (spinner_tx, spinner_rx) = oneshot::channel();
let mut spinner_tx = Some(spinner_tx);
tokio::spawn(run_spinner(" Generating", spinner_rx));
'outer: loop {
if abort.aborted() {
return Ok(());
}
spinner.step(writer)?;
for reply_event in gather_events(rx) {
spinner.stop(writer)?;
for reply_event in gather_events(&mut rx).await {
if let Some(spinner_tx) = spinner_tx.take() {
let _ = spinner_tx.send(());
}
match reply_event {
ReplyEvent::Text(mut text) => {
@ -135,10 +133,7 @@ fn markdown_stream_inner(
}
}
let timeout = tick_rate
.checked_sub(last_tick.elapsed())
.unwrap_or_else(|| tick_rate.div(2));
if crossterm::event::poll(timeout)? {
if crossterm::event::poll(Duration::from_millis(25))? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
@ -153,28 +148,31 @@ fn markdown_stream_inner(
}
}
}
if last_tick.elapsed() >= tick_rate {
last_tick = Instant::now();
}
}
spinner.stop(writer)?;
if let Some(spinner_tx) = spinner_tx.take() {
let _ = spinner_tx.send(());
}
Ok(())
}
fn gather_events(rx: &Receiver<ReplyEvent>) -> Vec<ReplyEvent> {
async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent> {
let mut texts = vec![];
let mut done = false;
for reply_event in rx.try_iter() {
match reply_event {
ReplyEvent::Text(v) => texts.push(v),
ReplyEvent::Done => {
done = true;
tokio::select! {
_ = async {
while let Some(reply_event) = rx.recv().await {
match reply_event {
ReplyEvent::Text(v) => texts.push(v),
ReplyEvent::Done => {
done = true;
break;
}
}
}
}
}
} => {}
_ = tokio::time::sleep(Duration::from_millis(50)) => {}
};
let mut events = vec![];
if !texts.is_empty() {
events.push(ReplyEvent::Text(texts.join("")))

@ -6,9 +6,9 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::{ensure_model_capabilities, init_client};
use crate::client::{ensure_model_capabilities, init_client, send_stream};
use crate::config::{GlobalConfig, Input, InputContext, State};
use crate::render::{render_error, render_stream};
use crate::render::render_error;
use crate::utils::{create_abort_signal, set_text, AbortSignal};
use anyhow::{bail, Context, Result};
@ -93,7 +93,7 @@ impl Repl {
})
}
pub fn run(&mut self) -> Result<()> {
pub async fn run(&mut self) -> Result<()> {
self.banner();
loop {
@ -104,7 +104,7 @@ impl Repl {
match sig {
Ok(Signal::Success(line)) => {
self.abort.reset();
match self.handle(&line) {
match self.handle(&line).await {
Ok(exit) => {
if exit {
break;
@ -127,11 +127,11 @@ impl Repl {
_ => {}
}
}
self.handle(".exit session")?;
self.handle(".exit session").await?;
Ok(())
}
fn handle(&self, mut line: &str) -> Result<bool> {
async fn handle(&self, mut line: &str) -> Result<bool> {
if let Ok(Some(captures)) = MULTILINE_RE.captures(line) {
if let Some(text_match) = captures.get(1) {
line = text_match.as_str();
@ -175,7 +175,7 @@ impl Repl {
let role = self.config.read().retrieve_role(name.trim())?;
let input =
Input::from_str(text.trim(), InputContext::new(Some(role), false));
self.ask(input)?;
self.ask(input).await?;
}
None => {
self.config.write().set_role(args)?;
@ -220,7 +220,7 @@ impl Repl {
};
let files = shell_words::split(files).with_context(|| "Invalid args")?;
let input = Input::new(text, files, self.config.read().input_context())?;
self.ask(input)?;
self.ask(input).await?;
}
None => println!("Usage: .file <files>... [-- <text>...]"),
},
@ -246,7 +246,7 @@ impl Repl {
},
None => {
let input = Input::from_str(line, self.config.read().input_context());
self.ask(input)?;
self.ask(input).await?;
}
}
@ -255,7 +255,7 @@ impl Repl {
Ok(false)
}
fn ask(&self, input: Input) -> Result<()> {
async fn ask(&self, input: Input) -> Result<()> {
if input.is_empty() {
return Ok(());
}
@ -265,7 +265,7 @@ impl Repl {
self.config.read().maybe_print_send_tokens(&input);
let mut client = init_client(&self.config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
let output = send_stream(&input, client.as_ref(), &self.config, self.abort.clone()).await?;
self.config.write().save_message(input, &output)?;
self.config.read().maybe_copy(&output);
if self.config.write().should_compress_session() {
@ -283,10 +283,9 @@ impl Repl {
color.italic().paint("compress_threshold"),
color.normal().paint("`."),
);
std::thread::spawn(move || -> anyhow::Result<()> {
let _ = compress_session(&config);
tokio::spawn(async move {
let _ = compress_session(&config).await;
config.write().end_compressing_session();
Ok(())
});
}
Ok(())
@ -443,14 +442,14 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
}
fn compress_session(config: &GlobalConfig) -> Result<()> {
async fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(
&config.read().summarize_prompt,
config.read().summarize_prompt(),
config.read().input_context(),
);
let mut client = init_client(config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let summary = client.send_message(input)?;
let summary = client.send_message(input).await?;
config.write().compress_session(&summary);
Ok(())
}

@ -9,7 +9,7 @@ pub use self::abort_signal::{create_abort_signal, AbortSignal};
pub use self::clipboard::set_text;
pub use self::prompt_input::*;
pub use self::render_prompt::render_prompt;
pub use self::spinner::{run_spinner, Spinner};
pub use self::spinner::run_spinner;
pub use self::tiktoken::cl100k_base_singleton;
use fancy_regex::Regex;
@ -82,14 +82,6 @@ pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option<bool> {
Some(light)
}
/// Builds a current-thread tokio runtime with all drivers enabled.
///
/// # Errors
///
/// Returns the builder's error wrapped with the context "Failed to init tokio".
pub fn init_tokio_runtime() -> anyhow::Result<tokio::runtime::Runtime> {
    use anyhow::Context;

    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    builder.build().context("Failed to init tokio")
}
pub fn sha256sum(input: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(input);

@ -2,10 +2,9 @@ use anyhow::Result;
use crossterm::{cursor, queue, style, terminal};
use std::{
io::{stdout, Stdout, Write},
sync::mpsc,
thread,
time::Duration,
};
use tokio::{sync::oneshot, time::interval};
pub struct Spinner {
index: usize,
@ -56,17 +55,20 @@ impl Spinner {
}
}
pub fn run_spinner(message: &str, rx: mpsc::Receiver<()>) -> Result<()> {
pub async fn run_spinner(message: &str, rx: oneshot::Receiver<()>) -> Result<()> {
let mut writer = stdout();
let mut spinner = Spinner::new(message);
loop {
spinner.step(&mut writer)?;
if let Ok(()) = rx.try_recv() {
let mut interval = interval(Duration::from_millis(50));
tokio::select! {
_ = async {
loop {
interval.tick().await;
let _ = spinner.step(&mut writer);
}
} => {}
_ = rx => {
spinner.stop(&mut writer)?;
break;
}
thread::sleep(Duration::from_millis(50))
}
spinner.stop(&mut writer)?;
Ok(())
}

Loading…
Cancel
Save