From d96950c23b8c8d63e11d576a0c9e3527695e35d3 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sun, 8 Sep 2024 08:15:57 +0800 Subject: [PATCH] feat: tolerate failure to load some rag files (#846) --- src/rag/loader.rs | 45 ++++++++++++++++++++++++++++++++++++++++++--- src/rag/mod.rs | 33 +++------------------------------ 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/rag/loader.rs b/src/rag/loader.rs index f6c3b8e..2ea75a9 100644 --- a/src/rag/loader.rs +++ b/src/rag/loader.rs @@ -6,6 +6,38 @@ use std::collections::HashMap; pub const EXTENSION_METADATA: &str = "__extension__"; pub const PATH_METADATA: &str = "__path__"; +pub async fn load_document( + loaders: &HashMap, + path: &str, + has_error: &mut bool, +) -> (String, Vec<(String, RagMetadata)>) { + let mut maybe_error = None; + let mut files = vec![]; + if is_url(path) { + if let Some(path) = path.strip_suffix("**") { + match load_recursive_url(loaders, path).await { + Ok(v) => files.extend(v), + Err(err) => maybe_error = Some(err), + } + } else { + match load_url(loaders, path).await { + Ok(v) => files.push(v), + Err(err) => maybe_error = Some(err), + } + } + } else { + match load_path(loaders, path, has_error).await { + Ok(v) => files.extend(v), + Err(err) => maybe_error = Some(err), + } + } + if let Some(err) = maybe_error { + *has_error = true; + println!("{}", warning_text(&format!("⚠️ {err:?}"))); + } + (path.to_string(), files) +} + pub async fn load_recursive_url( loaders: &HashMap, path: &str, @@ -37,7 +69,9 @@ pub async fn load_recursive_url( pub async fn load_path( loaders: &HashMap, path: &str, + has_error: &mut bool, ) -> Result> { + let path = Path::new(path).absolutize()?.display().to_string(); let file_paths = expand_glob_paths(&[path]).await?; let mut output = vec![]; let file_paths_len = file_paths.len(); @@ -46,10 +80,15 @@ pub async fn load_path( 1 => output.push(load_file(loaders, &file_paths[0]).await?), _ => { for path in file_paths { - println!("🚀 Loading file {path}"); - output.push(load_file(loaders, &path).await?) + println!("Load {path}"); + match load_file(loaders, &path).await { + Ok(v) => output.push(v), + Err(err) => { + *has_error = true; + println!("{}", warning_text(&format!("Error: {err:?}"))); + } + } } - println!("✨ Load directory completed"); } } Ok(output) diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 1b72867..aa1911e 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -275,16 +275,9 @@ impl Rag { for (index, path) in paths.iter().enumerate() { let path = path.as_ref(); println!("Load {path} [{}/{paths_len}]", index + 1); - match load_document(&loaders, path).await { - Ok((path, document_files)) => { - files.extend(document_files); - document_paths.push(path); - } - Err(err) => { - has_error = true; - println!("{}", warning_text(&format!("Error: {err:?}"))); - } - } + let (path, document_files) = load_document(&loaders, path, &mut has_error).await; + files.extend(document_files); + document_paths.push(path); } if has_error { @@ -729,26 +722,6 @@ fn add_documents() -> Result> { Ok(paths) } -async fn load_document( - loaders: &HashMap, - path: &str, -) -> Result<(String, Vec<(String, RagMetadata)>)> { - let mut files = vec![]; - if is_url(path) { - if let Some(path) = path.strip_suffix("**") { - files.extend(load_recursive_url(loaders, path).await?); - } else { - files.push(load_url(loaders, path).await?); - } - Ok((path.to_string(), files)) - } else { - let path = Path::new(path); - let path = path.absolutize()?.display().to_string(); - files.extend(load_path(loaders, &path).await?); - Ok((path.to_string(), files)) - } -} - fn progress(spinner: &Option, message: String) { if let Some(spinner) = spinner { let _ = spinner.set_message(message);