feat: rag load websites (#655)

Branch: pull/656/head
Author: sigoden · committed 4 months ago via GitHub
Parent: 03b40036a1 · Commit: 5985551aba

@@ -53,8 +53,10 @@ rag_min_score_rerank: 0             # Specifies the minimum relevance sc
 rag_document_loaders:
   # You can add more loaders, here is the syntax:
   #   <file-extension>: <command-to-load-the-file>
-  pdf: 'pdftotext $1 -'                        # Load .pdf file
+  pdf: 'pdftotext $1 -'                        # Load .pdf file, see https://poppler.freedesktop.org
   docx: 'pandoc --to plain $1'                 # Load .docx file
+  url: 'curl -fsSL $1'                         # Load url
+  # recursive_url: 'crawler $1 $2'             # Load websites
 
 # Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
 rag_template: |
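Note the `crawler $1 $2` entry ships commented out as a placeholder: any executable that takes the start URL as `$1` and writes its output to the file path passed as `$2` can serve as the `recursive_url` loader. A hypothetical wiring (the `my-crawler` command and its flags are illustrative, not part of this change):

    # $1 = start URL, $2 = file the loader must write its output to
    recursive_url: 'my-crawler --url $1 --out $2'

If the file written to `$2` contains a JSON array such as `[{"url": "https://...", "text": "..."}]`, each element becomes one document and the extra fields are kept as metadata; anything else is treated as a single plain-text document (see `parse_json_documents` below).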

@@ -224,9 +224,9 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
     let client = input.create_client()?;
     config.write().before_chat_completion(&input)?;
     let ret = if *IS_STDOUT_TERMINAL {
-        let (stop_spinner_tx, _) = run_spinner("Generating").await;
+        let spinner = create_spinner("Generating").await;
         let ret = client.chat_completions(input.clone()).await;
-        let _ = stop_spinner_tx.send(());
+        spinner.stop();
         ret
     } else {
         client.chat_completions(input.clone()).await

@@ -2,22 +2,43 @@ use super::*;
 use anyhow::{bail, Context, Result};
 use async_recursion::async_recursion;
-use std::{collections::HashMap, fs::read_to_string, path::Path};
+use serde_json::Value;
+use std::{collections::HashMap, env, fs::read_to_string, path::Path};
 
-pub fn load_file(
+pub const RECURSIVE_URL_LOADER: &str = "recursive_url";
+
+pub fn load(
     loaders: &HashMap<String, String>,
     path: &str,
     loader_name: &str,
 ) -> Result<Vec<RagDocument>> {
-    match loaders.get(loader_name) {
-        Some(loader_command) => load_with_command(path, loader_name, loader_command),
-        None => load_plain(path),
+    if loader_name == RECURSIVE_URL_LOADER {
+        let loader_command = loaders
+            .get(loader_name)
+            .with_context(|| format!("RAG document loader '{loader_name}' not configured"))?;
+        let contents = run_loader_command(path, loader_name, loader_command)?;
+        let output = match parse_json_documents(&contents) {
+            Some(v) => v,
+            None => vec![RagDocument::new(contents)],
+        };
+        Ok(output)
+    } else {
+        match loaders.get(loader_name) {
+            Some(loader_command) => load_with_command(path, loader_name, loader_command),
+            None => load_plain(path, loader_name),
+        }
     }
 }
 
-fn load_plain(path: &str) -> Result<Vec<RagDocument>> {
+fn load_plain(path: &str, loader_name: &str) -> Result<Vec<RagDocument>> {
     let contents = read_to_string(path)?;
-    let document = RagDocument::new(contents);
+    if loader_name == "json" {
+        if let Some(documents) = parse_json_documents(&contents) {
+            return Ok(documents);
+        }
+    }
+    let mut document = RagDocument::new(contents);
+    document.metadata.insert("path".into(), path.to_string());
     Ok(vec![document])
 }
@@ -26,29 +47,135 @@ fn load_with_command(
     loader_name: &str,
     loader_command: &str,
 ) -> Result<Vec<RagDocument>> {
-    let cmd_args = shell_words::split(loader_command)
-        .with_context(|| anyhow!("Invalid rag loader '{loader_name}': `{loader_command}`"))?;
+    let contents = run_loader_command(path, loader_name, loader_command)?;
+    let mut document = RagDocument::new(contents);
+    document.metadata.insert("path".into(), path.to_string());
+    Ok(vec![document])
+}
+
+fn run_loader_command(path: &str, loader_name: &str, loader_command: &str) -> Result<String> {
+    let cmd_args = shell_words::split(loader_command).with_context(|| {
+        anyhow!("Invalid rag document loader '{loader_name}': `{loader_command}`")
+    })?;
+    let mut use_stdout = true;
+    let outpath = env::temp_dir()
+        .join(format!("aichat-{}", sha256(path)))
+        .display()
+        .to_string();
     let cmd_args: Vec<_> = cmd_args
         .into_iter()
-        .map(|v| if v == "$1" { path.to_string() } else { v })
+        .map(|mut v| {
+            if v.contains("$1") {
+                v = v.replace("$1", path);
+            }
+            if v.contains("$2") {
+                use_stdout = false;
+                v = v.replace("$2", &outpath);
+            }
+            v
+        })
         .collect();
     let cmd_eval = shell_words::join(&cmd_args);
+    debug!("run `{cmd_eval}`");
     let (cmd, args) = cmd_args.split_at(1);
     let cmd = &cmd[0];
-    let (success, stdout, stderr) =
-        run_command_with_output(cmd, args, None).with_context(|| {
+    if use_stdout {
+        let (success, stdout, stderr) =
+            run_command_with_output(cmd, args, None).with_context(|| {
+                format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
+            })?;
+        if !success {
+            let err = if !stderr.is_empty() {
+                stderr
+            } else {
+                format!("The command `{cmd_eval}` exited with non-zero.")
+            };
+            bail!("{err}")
+        }
+        Ok(stdout)
+    } else {
+        let status = run_command(cmd, args, None).with_context(|| {
             format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
         })?;
-    if !success {
-        let err = if !stderr.is_empty() {
-            stderr
-        } else {
-            format!("The command `{cmd_eval}` exited with non-zero.")
-        };
-        bail!("{err}")
+        if status != 0 {
+            bail!("The command `{cmd_eval}` exited with non-zero.")
+        }
+        let contents =
+            read_to_string(&outpath).context("Failed to read file generated by the loader")?;
+        Ok(contents)
     }
-    let document = RagDocument::new(stdout);
-    Ok(vec![document])
+}
+
+fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
+    let value: Value = serde_json::from_str(data).ok()?;
+    let items = match value {
+        Value::Array(v) => v,
+        _ => return None,
+    };
+    if items.is_empty() {
+        return None;
+    }
+    match &items[0] {
+        Value::String(_) => {
+            let documents: Vec<_> = items
+                .into_iter()
+                .flat_map(|item| {
+                    if let Value::String(content) = item {
+                        Some(RagDocument::new(content))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Some(documents)
+        }
+        Value::Object(obj) => {
+            let key = [
+                "page_content",
+                "pageContent",
+                "content",
+                "html",
+                "markdown",
+                "text",
+                "data",
+            ]
+            .into_iter()
+            .map(|v| v.to_string())
+            .find(|key| obj.get(key).and_then(|v| v.as_str()).is_some())?;
+            let documents: Vec<_> = items
+                .into_iter()
+                .flat_map(|item| {
+                    if let Value::Object(mut obj) = item {
+                        if let Some(page_content) = obj.get(&key).and_then(|v| v.as_str()) {
+                            let page_content = page_content.to_string();
+                            obj.remove(&key);
+                            let metadata: IndexMap<_, _> = obj
+                                .into_iter()
+                                .map(|(k, v)| {
+                                    if let Value::String(v) = v {
+                                        (k, v)
+                                    } else {
+                                        (k, v.to_string())
+                                    }
+                                })
+                                .collect();
+                            return Some(RagDocument {
+                                page_content,
+                                metadata,
+                            });
+                        }
+                    }
+                    None
+                })
+                .collect();
+            if documents.is_empty() {
+                None
+            } else {
+                Some(documents)
+            }
+        }
+        _ => None,
+    }
 }
 
 pub fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
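In short, `run_loader_command` now supports two output contracts, selected by the placeholder the configured command uses (the two commands below are the examples from config.example.yaml; the temp-file name mirrors the `env::temp_dir()`/`sha256` logic above):

    pdf: 'pdftotext $1 -'              # only $1: output is captured from stdout
    # recursive_url: 'crawler $1 $2'   # $2 present: the command must write to the
                                       #   temp file, e.g. <tmpdir>/aichat-<sha256(path)>

For the `recursive_url` loader, `load` then hands the contents to `parse_json_documents`, which accepts a JSON array of strings, or of objects keyed by one of `page_content`, `pageContent`, `content`, `html`, `markdown`, `text`, or `data`; the remaining object fields become document metadata (non-string values are stringified).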
@@ -146,4 +273,36 @@ mod tests {
             ("C:\\dir".into(), vec!["md".into(), "txt".into()])
         );
     }
+
+    #[test]
+    fn test_parse_json_documents() {
+        let data = r#"["foo", "bar"]"#;
+        assert_eq!(
+            parse_json_documents(data).unwrap(),
+            vec![RagDocument::new("foo"), RagDocument::new("bar")]
+        );
+
+        let data = r#"[{"content": "foo"}, {"content": "bar"}]"#;
+        assert_eq!(
+            parse_json_documents(data).unwrap(),
+            vec![RagDocument::new("foo"), RagDocument::new("bar")]
+        );
+
+        let mut metadata = IndexMap::new();
+        metadata.insert("k1".into(), "1".into());
+        let data = r#"[{"k1": 1, "data": "foo" }]"#;
+        assert_eq!(
+            parse_json_documents(data).unwrap(),
+            vec![RagDocument::new("foo").with_metadata(metadata.clone())]
+        );
+
+        let data = r#""hello""#;
+        assert!(parse_json_documents(data).is_none());
+
+        let data = r#"{"key":"value"}"#;
+        assert!(parse_json_documents(data).is_none());
+
+        let data = r#"[{"key":"value"}]"#;
+        assert!(parse_json_documents(data).is_none());
+    }
 }
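Putting the loader pieces together, a minimal sketch of driving the new entry point for a website, assuming a crawler binary configured as in config.example.yaml (the command and URL are illustrative):

    use std::collections::HashMap;

    let mut loaders: HashMap<String, String> = HashMap::new();
    loaders.insert(RECURSIVE_URL_LOADER.to_string(), "crawler $1 $2".to_string());
    // $1 expands to the URL, $2 to the temp file the crawler must write;
    // a JSON array in that file yields one RagDocument per element.
    let docs = load(&loaders, "https://example.com/docs/", RECURSIVE_URL_LOADER)?;

Unlike the extension loaders, the recursive loader has no plain-text fallback: if `recursive_url` is not configured, `load` returns the "not configured" error above.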

@@ -20,7 +20,6 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
 use std::{fmt::Debug, io::BufReader, path::Path};
-use tokio::sync::mpsc;
 
 pub struct Rag {
     name: String,
@@ -61,14 +60,14 @@ impl Rag {
         };
         debug!("doc paths: {paths:?}");
         let loaders = config.read().rag_document_loaders.clone();
-        let (stop_spinner_tx, set_spinner_message_tx) = run_spinner("Starting").await;
+        let spinner = create_spinner("Starting").await;
         tokio::select! {
-            ret = rag.add_paths(loaders, &paths, Some(set_spinner_message_tx)) => {
-                let _ = stop_spinner_tx.send(());
+            ret = rag.add_paths(loaders, &paths, Some(spinner.clone())) => {
+                spinner.stop();
                 ret?;
             }
             _ = watch_abort_signal(abort_signal) => {
-                let _ = stop_spinner_tx.send(());
+                spinner.stop();
                 bail!("Aborted!")
             },
         };
@@ -207,7 +206,7 @@ impl Rag {
         rerank: Option<(Box<dyn Client>, f32)>,
         abort_signal: AbortSignal,
     ) -> Result<String> {
-        let (stop_spinner_tx, _) = run_spinner("Searching").await;
+        let spinner = create_spinner("Searching").await;
         let ret = tokio::select! {
             ret = self.hybird_search(text, top_k, min_score_vector_search, min_score_keyword_search, rerank) => {
                 ret
@@ -216,66 +215,99 @@ impl Rag {
                 bail!("Aborted!")
             },
         };
-        let _ = stop_spinner_tx.send(());
+        spinner.stop();
         let output = ret?.join("\n\n");
         Ok(output)
     }
 
-    pub async fn add_paths<T: AsRef<Path>>(
+    pub async fn add_paths<T: AsRef<str>>(
         &mut self,
         loaders: HashMap<String, String>,
         paths: &[T],
-        progress_tx: Option<mpsc::UnboundedSender<String>>,
+        spinner: Option<Spinner>,
     ) -> Result<()> {
+        let mut rag_files = vec![];
+
         // List files
-        let mut file_paths = vec![];
-        progress(&progress_tx, "Listing paths".into());
+        let mut new_paths = vec![];
+        progress(&spinner, "Gathering paths".into());
         for path in paths {
-            let path = path
-                .as_ref()
-                .absolutize()
-                .with_context(|| anyhow!("Invalid path '{}'", path.as_ref().display()))?;
-            let path_str = path.display().to_string();
-            if self.data.files.iter().any(|v| v.path == path_str) {
-                continue;
-            }
-            let (path_str, suffixes) = parse_glob(&path_str)?;
-            let suffixes = if suffixes.is_empty() {
-                None
-            } else {
-                Some(&suffixes)
-            };
-            list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
+            let path = path.as_ref();
+            if path.starts_with("http://") || path.starts_with("https://") {
+                if let Some(path) = path.strip_suffix("**") {
+                    new_paths.push((path.to_string(), RECURSIVE_URL_LOADER.into()));
+                } else {
+                    new_paths.push((path.to_string(), "url".into()))
+                }
+            } else {
+                let path = Path::new(path);
+                let path = path
+                    .absolutize()
+                    .with_context(|| anyhow!("Invalid path '{}'", path.display()))?;
+                let path_str = path.display().to_string();
+                if self.data.files.iter().any(|v| v.path == path_str) {
+                    continue;
+                }
+                let (path_str, suffixes) = parse_glob(&path_str)?;
+                let suffixes = if suffixes.is_empty() {
+                    None
+                } else {
+                    Some(&suffixes)
+                };
+                let mut file_paths = vec![];
+                list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
+                for file_path in file_paths {
+                    let loader_name = Path::new(&file_path)
+                        .extension()
+                        .map(|v| v.to_string_lossy().to_lowercase())
+                        .unwrap_or_default();
+                    new_paths.push((file_path, loader_name))
+                }
+            }
         }
 
         // Load files
-        let mut rag_files = vec![];
-        let file_paths_len = file_paths.len();
-        progress(&progress_tx, format!("Loading files [1/{file_paths_len}]"));
-        for path in file_paths {
-            let extension = Path::new(&path)
-                .extension()
-                .map(|v| v.to_string_lossy().to_lowercase())
-                .unwrap_or_default();
-            let separator = detect_separators(&extension);
-            let splitter = RecursiveCharacterTextSplitter::new(
-                self.data.chunk_size,
-                self.data.chunk_overlap,
-                &separator,
-            );
-            let documents = load_file(&loaders, &path, &extension)
-                .with_context(|| format!("Failed to load file at '{path}'"))?;
-            let split_options = SplitterChunkHeaderOptions::default().with_chunk_header(&format!(
-                "<document_metadata>\npath: {path}\n</document_metadata>\n\n"
-            ));
-            if !documents.is_empty() {
-                let documents = splitter.split_documents(&documents, &split_options);
-                rag_files.push(RagFile { path, documents });
-            }
-            progress(
-                &progress_tx,
-                format!("Loading files [{}/{file_paths_len}]", rag_files.len()),
-            );
+        let new_paths_len = new_paths.len();
+        if new_paths_len > 0 {
+            if let Some(spinner) = &spinner {
+                let _ = spinner.set_message(String::new());
+            }
+            for (index, (path, loader_name)) in new_paths.into_iter().enumerate() {
+                println!("Loading {path} [{}/{new_paths_len}]", index + 1);
+                let documents = load(&loaders, &path, &loader_name)
+                    .with_context(|| format!("Failed to load '{path}'"))?;
+                let separator = get_separators(&loader_name);
+                let splitter = RecursiveCharacterTextSplitter::new(
+                    self.data.chunk_size,
+                    self.data.chunk_overlap,
+                    &separator,
+                );
+                let splitted_documents: Vec<_> = documents
+                    .into_iter()
+                    .flat_map(|document| {
+                        let metadata = document
+                            .metadata
+                            .iter()
+                            .map(|(k, v)| format!("{k}: {v}\n"))
+                            .collect::<Vec<String>>()
+                            .join("");
+                        let split_options = SplitterChunkHeaderOptions::default()
+                            .with_chunk_header(&format!(
+                                "<document_metadata>\n{metadata}</document_metadata>\n\n"
+                            ));
+                        splitter.split_documents(&[document], &split_options)
+                    })
+                    .collect();
+                let display_path = if loader_name == RECURSIVE_URL_LOADER {
+                    format!("{path}**")
+                } else {
+                    path
+                };
+                rag_files.push(RagFile {
+                    path: display_path,
+                    documents: splitted_documents,
+                })
+            }
         }
 
         if rag_files.is_empty() {
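The routing that the reworked `add_paths` applies, summarized (the paths are illustrative):

    docs/guide.md               -> loader "md"            (lowercased file extension)
    https://example.com/page    -> loader "url"
    https://example.com/docs/** -> loader "recursive_url" (the trailing ** is stripped
                                   before loading and re-appended in RagFile.path)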
@@ -294,11 +326,11 @@ impl Rag {
         let embeddings_data = EmbeddingsData::new(texts, false);
         let embeddings = self
-            .create_embeddings(embeddings_data, progress_tx.clone())
+            .create_embeddings(embeddings_data, spinner.clone())
             .await?;
 
         self.data.add(rag_files, vector_ids, embeddings);
-        progress(&progress_tx, "Building vector store".into());
+        progress(&spinner, "Building vector store".into());
         self.hnsw = self.data.build_hnsw();
         self.bm25 = self.data.build_bm25();
@@ -418,17 +450,17 @@ impl Rag {
     async fn create_embeddings(
         &self,
         data: EmbeddingsData,
-        progress_tx: Option<mpsc::UnboundedSender<String>>,
+        spinner: Option<Spinner>,
     ) -> Result<EmbeddingsOutput> {
         let EmbeddingsData { texts, query } = data;
         let mut output = vec![];
         let batch_chunks = texts.chunks(self.embedding_model.max_batch_size());
         let batch_chunks_len = batch_chunks.len();
-        progress(
-            &progress_tx,
-            format!("Creating embeddings [1/{batch_chunks_len}]"),
-        );
         for (index, texts) in batch_chunks.enumerate() {
+            progress(
+                &spinner,
+                format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1),
+            );
             let chunk_data = EmbeddingsData {
                 texts: texts.to_vec(),
                 query,
@@ -439,10 +471,6 @@ impl Rag {
                 .await
                 .context("Failed to create embedding")?;
             output.extend(chunk_output);
-            progress(
-                &progress_tx,
-                format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1),
-            );
         }
         Ok(output)
     }
@@ -510,7 +538,7 @@ pub struct RagFile {
     documents: Vec<RagDocument>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct RagDocument {
     pub page_content: String,
     pub metadata: RagMetadata,
@@ -603,15 +631,15 @@ fn set_chunk_overlay(default_value: usize) -> Result<usize> {
 fn add_doc_paths() -> Result<Vec<String>> {
     let text = Text::new("Add document paths:")
         .with_validator(required!("This field is required"))
-        .with_help_message("e.g. file1;dir2/;dir3/**/*.md")
+        .with_help_message("e.g. file;dir/;dir/**/*.md;url;sites/**")
         .prompt()?;
     let paths = text.split(';').map(|v| v.to_string()).collect();
     Ok(paths)
 }
 
-fn progress(spinner_message_tx: &Option<mpsc::UnboundedSender<String>>, message: String) {
-    if let Some(tx) = spinner_message_tx {
-        let _ = tx.send(message);
+fn progress(spinner: &Option<Spinner>, message: String) {
+    if let Some(spinner) = spinner {
+        let _ = spinner.set_message(message);
     }
 }

@@ -8,7 +8,7 @@ use std::cmp::Ordering;
 pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""];
 
-pub fn detect_separators(extension: &str) -> Vec<&'static str> {
+pub fn get_separators(extension: &str) -> Vec<&'static str> {
     match extension {
         "c" | "cc" | "cpp" => Language::Cpp.separators(),
         "go" => Language::Go.separators(),
@@ -149,16 +149,7 @@
         }
         let newlines_count = self.number_of_newlines(&chunk, 0, chunk.len());
-        let mut metadata = metadatas[i].clone();
-        metadata.insert(
-            "loc".into(),
-            format!(
-                "{}:{}",
-                line_counter_index,
-                line_counter_index + newlines_count
-            ),
-        );
+        let metadata = metadatas[i].clone();
         page_content += &chunk;
         documents.push(RagDocument {
             page_content,
@@ -348,11 +339,8 @@ mod tests {
     use pretty_assertions::assert_eq;
     use serde_json::{json, Value};
 
-    fn build_metadata(source: &str, loc_from_line: usize, loc_to_line: usize) -> Value {
-        json!({
-            "source": source,
-            "loc": format!("{loc_from_line}:{loc_to_line}"),
-        })
+    fn build_metadata(source: &str) -> Value {
+        json!({ "source": source })
     }
 
     #[test]
     fn test_split_text() {
@@ -385,15 +373,15 @@ mod tests {
             json!([
                 {
                     "page_content": "foo",
-                    "metadata": build_metadata("1", 1, 1),
+                    "metadata": build_metadata("1"),
                 },
                 {
                     "page_content": "bar",
-                    "metadata": build_metadata("1", 1, 1),
+                    "metadata": build_metadata("1"),
                 },
                 {
                     "page_content": "baz",
-                    "metadata": build_metadata("2", 1, 1),
+                    "metadata": build_metadata("2"),
                 },
             ])
         );
@@ -420,15 +408,15 @@ mod tests {
             json!([
                 {
                     "page_content": "SOURCE NAME: testing\n-----\nfoo",
-                    "metadata": build_metadata("1", 1, 1),
+                    "metadata": build_metadata("1"),
                 },
                 {
                     "page_content": "SOURCE NAME: testing\n-----\n(cont'd) bar",
-                    "metadata": build_metadata("1", 1, 1),
+                    "metadata": build_metadata("1"),
                 },
                 {
                     "page_content": "SOURCE NAME: testing\n-----\nbaz",
-                    "metadata": build_metadata("2", 1, 1),
+                    "metadata": build_metadata("2"),
                },
            ])
        );

@@ -1,6 +1,6 @@
 use super::{MarkdownRender, SseEvent};
-use crate::utils::{run_spinner, AbortSignal};
+use crate::utils::{create_spinner, AbortSignal};
 use anyhow::Result;
 use crossterm::{
@@ -62,16 +62,15 @@ async fn markdown_stream_inner(
     let columns = terminal::size()?.0;
 
-    let (stop_spinner_tx, _) = run_spinner("Generating").await;
-    let mut stop_spinner_tx = Some(stop_spinner_tx);
+    let mut spinner = Some(create_spinner("Generating").await);
 
     'outer: loop {
         if abort.aborted() {
             return Ok(());
         }
 
         for reply_event in gather_events(&mut rx).await {
-            if let Some(stop_spinner_tx) = stop_spinner_tx.take() {
-                let _ = stop_spinner_tx.send(());
+            if let Some(spinner) = spinner.take() {
+                spinner.stop();
             }
 
             match reply_event {
@@ -149,8 +148,8 @@ async fn markdown_stream_inner(
         }
     }
 
-    if let Some(stop_spinner_tx) = stop_spinner_tx.take() {
-        let _ = stop_spinner_tx.send(());
+    if let Some(spinner) = spinner.take() {
+        spinner.stop();
     }
 
     Ok(())
 }

@@ -12,7 +12,7 @@ pub use self::command::*;
 pub use self::crypto::*;
 pub use self::prompt_input::*;
 pub use self::render_prompt::render_prompt;
-pub use self::spinner::run_spinner;
+pub use self::spinner::{create_spinner, Spinner};
 
 use fancy_regex::Regex;
 use is_terminal::IsTerminal;

@@ -5,46 +5,34 @@ use std::{
     io::{stdout, Write},
     time::Duration,
 };
-use tokio::{
-    sync::{mpsc, oneshot},
-    time::interval,
-};
+use tokio::{sync::mpsc, time::interval};
 
-pub struct Spinner {
+pub struct SpinnerInner {
     index: usize,
     message: String,
-    stopped: bool,
+    is_not_terminal: bool,
 }
 
-impl Spinner {
+impl SpinnerInner {
     const DATA: [&'static str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];
 
-    pub fn new(message: &str) -> Self {
-        Spinner {
+    fn new(message: &str) -> Self {
+        SpinnerInner {
             index: 0,
             message: message.to_string(),
-            stopped: false,
+            is_not_terminal: !stdout().is_terminal(),
         }
     }
 
-    pub fn set_message(&mut self, message: &str) {
-        self.message = format!(" {message}");
-    }
-
-    pub fn step(&mut self) -> Result<()> {
-        if self.stopped {
+    fn step(&mut self) -> Result<()> {
+        if self.is_not_terminal || self.message.is_empty() {
             return Ok(());
         }
         let mut writer = stdout();
         let frame = Self::DATA[self.index % Self::DATA.len()];
         let dots = ".".repeat((self.index / 5) % 4);
         let line = format!("{frame}{}{:<3}", self.message, dots);
-        queue!(
-            writer,
-            cursor::MoveToColumn(0),
-            terminal::Clear(terminal::ClearType::FromCursorDown),
-            style::Print(line),
-        )?;
+        queue!(writer, cursor::MoveToColumn(0), style::Print(line),)?;
         if self.index == 0 {
             queue!(writer, cursor::Hide)?;
         }
@@ -53,12 +41,20 @@ impl Spinner {
         Ok(())
     }
 
-    pub fn stop(&mut self) -> Result<()> {
-        if self.stopped {
+    fn set_message(&mut self, message: String) -> Result<()> {
+        self.clear_message()?;
+        if !message.is_empty() {
+            self.message = format!(" {message}");
+        }
+        Ok(())
+    }
+
+    fn clear_message(&mut self) -> Result<()> {
+        if self.is_not_terminal || self.message.is_empty() {
             return Ok(());
         }
+        self.message.clear();
         let mut writer = stdout();
-        self.stopped = true;
         queue!(
             writer,
             cursor::MoveToColumn(0),
@@ -70,43 +66,60 @@ impl Spinner {
     }
 }
 
-pub async fn run_spinner(message: &str) -> (oneshot::Sender<()>, mpsc::UnboundedSender<String>) {
+#[derive(Clone)]
+pub struct Spinner(mpsc::UnboundedSender<SpinnerEvent>);
+
+impl Drop for Spinner {
+    fn drop(&mut self) {
+        self.stop();
+    }
+}
+
+impl Spinner {
+    pub fn set_message(&self, message: String) -> Result<()> {
+        self.0.send(SpinnerEvent::SetMessage(message))?;
+        Ok(())
+    }
+
+    pub fn stop(&self) {
+        let _ = self.0.send(SpinnerEvent::Stop);
+    }
+}
+
+enum SpinnerEvent {
+    SetMessage(String),
+    Stop,
+}
+
+pub async fn create_spinner(message: &str) -> Spinner {
     let message = format!(" {message}");
-    let (stop_tx, stop_rx) = oneshot::channel();
-    let (message_tx, message_rx) = mpsc::unbounded_channel();
-    tokio::spawn(run_spinner_inner(message, stop_rx, message_rx));
-    (stop_tx, message_tx)
+    let (tx, rx) = mpsc::unbounded_channel();
+    tokio::spawn(run_spinner(message, rx));
+    Spinner(tx)
 }
 
-async fn run_spinner_inner(
-    message: String,
-    stop_rx: oneshot::Receiver<()>,
-    mut message_rx: mpsc::UnboundedReceiver<String>,
-) -> Result<()> {
-    let is_stdout_terminal = stdout().is_terminal();
-    let mut spinner = Spinner::new(&message);
+async fn run_spinner(message: String, mut rx: mpsc::UnboundedReceiver<SpinnerEvent>) -> Result<()> {
+    let mut spinner = SpinnerInner::new(&message);
     let mut interval = interval(Duration::from_millis(50));
-    tokio::select! {
-        _ = async {
-            loop {
-                tokio::select! {
-                    _ = interval.tick() => {
-                        if is_stdout_terminal {
-                            let _ = spinner.step();
-                        }
-                    }
-                    message = message_rx.recv() => {
-                        if let Some(message) = message {
-                            spinner.set_message(&message);
-                        }
-                    }
-                }
-            }
-        } => {}
-        _ = stop_rx => {
-            if is_stdout_terminal {
-                spinner.stop()?;
-            }
-        }
-    }
+    loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                let _ = spinner.step();
+            }
+            evt = rx.recv() => {
+                if let Some(evt) = evt {
+                    match evt {
+                        SpinnerEvent::SetMessage(message) => {
+                            spinner.set_message(message)?;
+                        }
+                        SpinnerEvent::Stop => {
+                            spinner.clear_message()?;
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+    }
     Ok(())
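For callers, the old pair of channels (a `oneshot::Sender<()>` to stop plus an `mpsc::UnboundedSender<String>` for messages) collapses into a single cloneable `Spinner` handle. A usage sketch based on the API above, inside an async fn returning Result (the message strings are illustrative):

    let spinner = create_spinner("Indexing").await;
    spinner.set_message("Creating embeddings [1/4]".to_string())?;
    // ... do work ...
    spinner.stop(); // optional: dropping a handle also sends SpinnerEvent::Stop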
