feat: ask for confirmation when some rag documents fail to load (#760)

pull/761/head
Authored by sigoden, committed by GitHub, 3 months ago
parent 49b61129c9
commit c5b2be641a

@@ -68,7 +68,7 @@ impl Agent {
         println!("The agent has the documents, initializing RAG...");
         let mut document_paths = vec![];
         for path in &definition.documents {
-            if Rag::is_url_path(path) {
+            if is_url(path) {
                 document_paths.push(path.to_string());
             } else {
                 let new_path = safe_join_path(&functions_dir, path)
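Note: this hunk only swaps Rag::is_url_path for the shared is_url utility added at the bottom of this commit; URL documents are recorded verbatim while local ones are resolved under the agent's functions directory. A rough sketch of that split, where join_under_dir is a hypothetical stand-in for the project's safe_join_path:

use std::path::{Path, PathBuf};

fn is_url(path: &str) -> bool {
    path.starts_with("http://") || path.starts_with("https://")
}

// Hypothetical stand-in for safe_join_path; the real helper presumably also
// prevents the joined path from escaping the base directory.
fn join_under_dir(dir: &Path, rel: &str) -> Option<PathBuf> {
    Some(dir.join(rel))
}

fn collect_document_paths(functions_dir: &Path, documents: &[String]) -> Vec<String> {
    let mut document_paths = vec![];
    for path in documents {
        if is_url(path) {
            // URLs are kept as-is.
            document_paths.push(path.to_string());
        } else if let Some(local) = join_under_dir(functions_dir, path) {
            // Local documents are addressed relative to the functions directory.
            document_paths.push(local.display().to_string());
        }
    }
    document_paths
}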

@@ -14,7 +14,7 @@ use anyhow::bail;
 use anyhow::{anyhow, Context, Result};
 use hnsw_rs::prelude::*;
 use indexmap::{IndexMap, IndexSet};
-use inquire::{required, validator::Validation, Select, Text};
+use inquire::{required, validator::Validation, Confirm, Select, Text};
 use path_absolutize::Absolutize;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
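The only change to the imports is pulling in inquire::Confirm, the yes/no prompt used further down. A minimal standalone sketch of that API (prompt text is illustrative):

use inquire::Confirm;

fn main() -> anyhow::Result<()> {
    // Ask a yes/no question; pressing Enter takes the default ("no" here).
    let proceed = Confirm::new("Some documents failed to load. Continue?")
        .with_default(false)
        .prompt()?;
    if !proceed {
        anyhow::bail!("Aborted");
    }
    println!("Continuing...");
    Ok(())
}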
@@ -62,7 +62,7 @@ impl Rag {
         let loaders = config.read().document_loaders.clone();
         let spinner = create_spinner("Starting").await;
         tokio::select! {
-            ret = rag.load_paths(loaders, &paths, Some(spinner.clone())) => {
+            ret = rag.sync_documents(loaders, &paths, Some(spinner.clone())) => {
                 spinner.stop();
                 ret?;
             }
@@ -114,7 +114,7 @@ impl Rag {
         let spinner = create_spinner("Starting").await;
         let paths = self.data.document_paths.clone();
         tokio::select! {
-            ret = self.load_paths(loaders, &paths, Some(spinner.clone())) => {
+            ret = self.sync_documents(loaders, &paths, Some(spinner.clone())) => {
                 spinner.stop();
                 ret?;
             }
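Both call sites simply pick up the new method name; the surrounding tokio::select! races the sync against other branches not shown in these hunks, which is what lets a long-running load be cancelled. A generic sketch of that pattern, assuming a Ctrl-C branch rather than whatever signal the project actually selects on:

use std::time::Duration;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tokio::select! {
        // Stand-in for rag.sync_documents(...): some long-running async work.
        ret = async {
            tokio::time::sleep(Duration::from_secs(3)).await;
            anyhow::Ok("synced")
        } => {
            println!("{}", ret?);
        }
        // Cancellation branch.
        _ = tokio::signal::ctrl_c() => {
            anyhow::bail!("Aborted");
        }
    }
    Ok(())
}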
@@ -258,7 +258,7 @@ impl Rag {
         Ok(output)
     }

-    pub async fn load_paths<T: AsRef<str>>(
+    pub async fn sync_documents<T: AsRef<str>>(
         &mut self,
         loaders: HashMap<String, String>,
         paths: &[T],
@@ -271,22 +271,33 @@ impl Rag {
         let mut document_paths = vec![];
         let mut files = vec![];
         let paths_len = paths.len();
+        let mut has_error = false;
         for (index, path) in paths.iter().enumerate() {
             let path = path.as_ref();
             println!("Load {path} [{}/{paths_len}]", index + 1);
-            if Self::is_url_path(path) {
-                if let Some(path) = path.strip_suffix("**") {
-                    files.extend(load_recursive_url(&loaders, path).await?);
-                } else {
-                    files.push(load_url(&loaders, path).await?);
-                }
-                document_paths.push(path.to_string());
-            } else {
-                let path = Path::new(path);
-                let path = path.absolutize()?.display().to_string();
-                files.extend(load_path(&loaders, &path).await?);
-                document_paths.push(path);
+            match load_document(&loaders, path).await {
+                Ok((path, document_files)) => {
+                    files.extend(document_files);
+                    document_paths.push(path);
+                }
+                Err(err) => {
+                    has_error = true;
+                    println!("{}", warning_text(&format!("Error: {err:?}")));
+                }
             }
         }
+        if has_error {
+            let mut aborted = true;
+            if *IS_STDOUT_TERMINAL && !document_paths.is_empty() {
+                let ans = Confirm::new("Some documents failed to load. Continue?")
+                    .with_default(false)
+                    .prompt()?;
+                aborted = !ans;
+            }
+            if aborted {
+                bail!("Aborted");
+            }
+        }
         let mut to_deleted: IndexMap<String, FileId> = Default::default();
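This is the heart of the change: a failed document no longer aborts the loop via ?; it sets has_error, and once every path has been attempted, an interactive confirmation (only offered on a terminal and only if at least one document did load) decides between continuing and bailing out. A simplified, self-contained sketch of that flow, with fake_load standing in for load_document and std::io::IsTerminal standing in for the project's IS_STDOUT_TERMINAL flag:

use anyhow::{bail, Result};
use inquire::Confirm;
use std::io::IsTerminal;

// Hypothetical loader: pretend anything containing "bad" fails to load.
fn fake_load(path: &str) -> Result<String> {
    if path.contains("bad") {
        bail!("unsupported document: {path}");
    }
    Ok(path.to_string())
}

fn sync(paths: &[&str]) -> Result<Vec<String>> {
    let mut loaded = vec![];
    let mut has_error = false;
    for path in paths {
        match fake_load(path) {
            Ok(doc) => loaded.push(doc),
            Err(err) => {
                // Report the failure but keep going.
                has_error = true;
                eprintln!("Error: {err:?}");
            }
        }
    }
    if has_error {
        // Default to aborting; only a terminal user with at least one
        // successfully loaded document gets asked whether to continue.
        let mut aborted = true;
        if std::io::stdout().is_terminal() && !loaded.is_empty() {
            aborted = !Confirm::new("Some documents failed to load. Continue?")
                .with_default(false)
                .prompt()?;
        }
        if aborted {
            bail!("Aborted");
        }
    }
    Ok(loaded)
}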
@@ -367,10 +378,6 @@ impl Rag {
         Ok(())
     }

-    pub fn is_url_path(path: &str) -> bool {
-        path.starts_with("http://") || path.starts_with("https://")
-    }
-
     async fn hybird_search(
         &self,
         query: &str,
@@ -708,6 +715,26 @@ fn add_documents() -> Result<Vec<String>> {
     Ok(paths)
 }

+async fn load_document(
+    loaders: &HashMap<String, String>,
+    path: &str,
+) -> Result<(String, Vec<(String, RagMetadata)>)> {
+    let mut files = vec![];
+    if is_url(path) {
+        if let Some(path) = path.strip_suffix("**") {
+            files.extend(load_recursive_url(loaders, path).await?);
+        } else {
+            files.push(load_url(loaders, path).await?);
+        }
+        Ok((path.to_string(), files))
+    } else {
+        let path = Path::new(path);
+        let path = path.absolutize()?.display().to_string();
+        files.extend(load_path(loaders, &path).await?);
+        Ok((path.to_string(), files))
+    }
+}
+
 fn progress(spinner: &Option<Spinner>, message: String) {
     if let Some(spinner) = spinner {
         let _ = spinner.set_message(message);
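The new load_document helper folds both branches of the old loop body into one fallible call: URLs pass straight through (a trailing ** requests a recursive crawl), while local paths are absolutized via the path_absolutize crate before loading. A stripped-down sketch of just that normalization step, with the actual loading omitted (normalize_document_path is an illustrative name, not the project's):

use anyhow::Result;
use path_absolutize::Absolutize;
use std::path::Path;

fn is_url(path: &str) -> bool {
    path.starts_with("http://") || path.starts_with("https://")
}

// Returns the canonical document path plus whether a recursive URL crawl was requested.
fn normalize_document_path(path: &str) -> Result<(String, bool)> {
    if is_url(path) {
        // Keep the URL itself as the document path; "**" only affects how it is crawled.
        Ok((path.to_string(), path.ends_with("**")))
    } else {
        // Local paths are stored in absolute form.
        let path = Path::new(path).absolutize()?.display().to_string();
        Ok((path, false))
    }
}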

@@ -145,6 +145,10 @@ pub fn temp_file(prefix: &str, suffix: &str) -> PathBuf {
     ))
 }

+pub fn is_url(path: &str) -> bool {
+    path.starts_with("http://") || path.starts_with("https://")
+}
+
 pub fn set_proxy(
     builder: reqwest::ClientBuilder,
     proxy: Option<&String>,
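Since is_url is just a scheme-prefix check, its behavior is easy to pin down; a hypothetical unit test, assuming it sits next to the function above:

#[test]
fn is_url_prefix_check() {
    assert!(is_url("https://example.com/docs"));
    assert!(is_url("http://localhost:8080/readme"));
    assert!(!is_url("./local/notes.md"));
    // Only http(s) counts; anything else is treated as a filesystem path.
    assert!(!is_url("ftp://example.com/file"));
}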
