@@ -20,7 +20,6 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
 use std::{fmt::Debug, io::BufReader, path::Path};
-use tokio::sync::mpsc;
 
 pub struct Rag {
     name: String,
@@ -61,14 +60,14 @@ impl Rag {
         };
         debug!("doc paths: {paths:?}");
         let loaders = config.read().rag_document_loaders.clone();
-        let (stop_spinner_tx, set_spinner_message_tx) = run_spinner("Starting").await;
+        let spinner = create_spinner("Starting").await;
         tokio::select! {
-            ret = rag.add_paths(loaders, &paths, Some(set_spinner_message_tx)) => {
-                let _ = stop_spinner_tx.send(());
+            ret = rag.add_paths(loaders, &paths, Some(spinner.clone())) => {
+                spinner.stop();
                 ret?;
             }
             _ = watch_abort_signal(abort_signal) => {
-                let _ = stop_spinner_tx.send(());
+                spinner.stop();
                 bail!("Aborted!")
             },
         };
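
Note: `create_spinner` replaces `run_spinner`, collapsing the old two-channel contract (a stop sender plus a message sender) into one cloneable `Spinner` handle that either `tokio::select!` arm can stop. A minimal sketch of the surface this diff assumes; the real type lives elsewhere in the crate, and the channel-backed internals here are illustrative only:

    use tokio::sync::mpsc;

    #[derive(Debug)]
    pub enum SpinnerEvent {
        SetMessage(String),
        Stop,
    }

    // One cloneable handle instead of a (stop_tx, set_message_tx) pair.
    #[derive(Clone)]
    pub struct Spinner {
        tx: mpsc::UnboundedSender<SpinnerEvent>,
    }

    impl Spinner {
        pub fn set_message(&self, message: String) -> Result<(), mpsc::error::SendError<SpinnerEvent>> {
            self.tx.send(SpinnerEvent::SetMessage(message))
        }

        pub fn stop(&self) {
            let _ = self.tx.send(SpinnerEvent::Stop);
        }
    }

Because the handle is `Clone`, `add_paths` can take `Option<Spinner>` and report progress without the caller threading two senders through every call.
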
@@ -207,7 +206,7 @@ impl Rag {
         rerank: Option<(Box<dyn Client>, f32)>,
         abort_signal: AbortSignal,
     ) -> Result<String> {
-        let (stop_spinner_tx, _) = run_spinner("Searching").await;
+        let spinner = create_spinner("Searching").await;
         let ret = tokio::select! {
             ret = self.hybird_search(text, top_k, min_score_vector_search, min_score_keyword_search, rerank) => {
                 ret
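
The search path gets the same treatment: the spinner is created once, then `tokio::select!` races the query against the abort signal. The pattern in isolation, with assumed shapes for the two futures (a sketch, not the crate's API):

    use anyhow::{bail, Result};
    use std::future::Future;

    // Run `work` to completion unless `abort` fires first.
    async fn run_with_abort<T>(
        work: impl Future<Output = Result<T>>,
        abort: impl Future<Output = ()>,
    ) -> Result<T> {
        tokio::select! {
            ret = work => ret,
            _ = abort => bail!("Aborted!"),
        }
    }

Stopping the spinner once after the `select!` (rather than inside each arm, as the add flow above does) works here because both arms fall through to the same `spinner.stop()` call, visible at the top of the next hunk.
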
@@ -216,66 +215,99 @@ impl Rag {
                 bail!("Aborted!")
             },
         };
-        let _ = stop_spinner_tx.send(());
+        spinner.stop();
         let output = ret?.join("\n\n");
         Ok(output)
     }
 
-    pub async fn add_paths<T: AsRef<Path>>(
+    pub async fn add_paths<T: AsRef<str>>(
         &mut self,
         loaders: HashMap<String, String>,
         paths: &[T],
-        progress_tx: Option<mpsc::UnboundedSender<String>>,
+        spinner: Option<Spinner>,
     ) -> Result<()> {
-        let mut rag_files = vec![];
-
-        // List files
-        let mut file_paths = vec![];
-        progress(&progress_tx, "Listing paths".into());
+        let mut new_paths = vec![];
+        progress(&spinner, "Gathering paths".into());
         for path in paths {
-            let path = path
-                .as_ref()
-                .absolutize()
-                .with_context(|| anyhow!("Invalid path '{}'", path.as_ref().display()))?;
-            let path_str = path.display().to_string();
-            if self.data.files.iter().any(|v| v.path == path_str) {
-                continue;
-            }
-            let (path_str, suffixes) = parse_glob(&path_str)?;
-            let suffixes = if suffixes.is_empty() {
-                None
+            let path = path.as_ref();
+            if path.starts_with("http://") || path.starts_with("https://") {
+                if let Some(path) = path.strip_suffix("**") {
+                    new_paths.push((path.to_string(), RECURSIVE_URL_LOADER.into()));
+                } else {
+                    new_paths.push((path.to_string(), "url".into()))
+                }
             } else {
-                Some(&suffixes)
-            };
-            list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
+                let path = Path::new(path);
+                let path = path
+                    .absolutize()
+                    .with_context(|| anyhow!("Invalid path '{}'", path.display()))?;
+                let path_str = path.display().to_string();
+                if self.data.files.iter().any(|v| v.path == path_str) {
+                    continue;
+                }
+                let (path_str, suffixes) = parse_glob(&path_str)?;
+                let suffixes = if suffixes.is_empty() {
+                    None
+                } else {
+                    Some(&suffixes)
+                };
+                let mut file_paths = vec![];
+                list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
+                for file_path in file_paths {
+                    let loader_name = Path::new(&file_path)
+                        .extension()
+                        .map(|v| v.to_string_lossy().to_lowercase())
+                        .unwrap_or_default();
+                    new_paths.push((file_path, loader_name))
+                }
+            }
         }
 
         // Load files
+        let mut rag_files = vec![];
-        let file_paths_len = file_paths.len();
-        progress(&progress_tx, format!("Loading files [1/{file_paths_len}]"));
-        for path in file_paths {
-            let extension = Path::new(&path)
-                .extension()
-                .map(|v| v.to_string_lossy().to_lowercase())
-                .unwrap_or_default();
-            let separator = detect_separators(&extension);
-            let splitter = RecursiveCharacterTextSplitter::new(
-                self.data.chunk_size,
-                self.data.chunk_overlap,
-                &separator,
-            );
-            let documents = load_file(&loaders, &path, &extension)
-                .with_context(|| format!("Failed to load file at '{path}'"))?;
-            let split_options = SplitterChunkHeaderOptions::default().with_chunk_header(&format!(
-                "<document_metadata>\npath: {path}\n</document_metadata>\n\n"
-            ));
-            if !documents.is_empty() {
-                let documents = splitter.split_documents(&documents, &split_options);
-                rag_files.push(RagFile { path, documents });
+        let new_paths_len = new_paths.len();
+        if new_paths_len > 0 {
+            if let Some(spinner) = &spinner {
+                let _ = spinner.set_message(String::new());
+            }
+            for (index, (path, loader_name)) in new_paths.into_iter().enumerate() {
+                println!("Loading {path} [{}/{new_paths_len}]", index + 1);
+                let documents = load(&loaders, &path, &loader_name)
+                    .with_context(|| format!("Failed to load '{path}'"))?;
+                let separator = get_separators(&loader_name);
+                let splitter = RecursiveCharacterTextSplitter::new(
+                    self.data.chunk_size,
+                    self.data.chunk_overlap,
+                    &separator,
+                );
+                let splitted_documents: Vec<_> = documents
+                    .into_iter()
+                    .flat_map(|document| {
+                        let metadata = document
+                            .metadata
+                            .iter()
+                            .map(|(k, v)| format!("{k}: {v}\n"))
+                            .collect::<Vec<String>>()
+                            .join("");
+                        let split_options = SplitterChunkHeaderOptions::default()
+                            .with_chunk_header(&format!(
+                                "<document_metadata>\n{metadata}</document_metadata>\n\n"
+                            ));
+                        splitter.split_documents(&[document], &split_options)
+                    })
+                    .collect();
+                let display_path = if loader_name == RECURSIVE_URL_LOADER {
+                    format!("{path}**")
+                } else {
+                    path
+                };
+                rag_files.push(RagFile {
+                    path: display_path,
+                    documents: splitted_documents,
+                })
             }
-            progress(
-                &progress_tx,
-                format!("Loading files [{}/{file_paths_len}]", rag_files.len()),
-            );
         }
 
         if rag_files.is_empty() {
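
The heart of the change: `add_paths` now takes `AsRef<str>` instead of `AsRef<Path>` so one argument list can mix local paths and URLs, and the flat "list files, then load by extension" pipeline becomes a routing step that pairs every input with a loader name. The routing logic in isolation (the function name is mine, and `RECURSIVE_URL_LOADER`'s value is assumed for illustration):

    use std::path::Path;

    const RECURSIVE_URL_LOADER: &str = "recursive_url";

    /// Pair an input with the loader that should handle it: http(s) inputs go
    /// to the URL loader, a trailing `**` selects the recursive URL loader,
    /// and local paths are keyed by their lowercase file extension.
    fn route_path(path: &str) -> (String, String) {
        if path.starts_with("http://") || path.starts_with("https://") {
            match path.strip_suffix("**") {
                Some(base) => (base.to_string(), RECURSIVE_URL_LOADER.to_string()),
                None => (path.to_string(), "url".to_string()),
            }
        } else {
            let loader_name = Path::new(path)
                .extension()
                .map(|v| v.to_string_lossy().to_lowercase())
                .unwrap_or_default();
            (path.to_string(), loader_name)
        }
    }

So `route_path("https://example.com/docs/**")` yields `("https://example.com/docs/", "recursive_url")`, while `route_path("notes.md")` yields `("notes.md", "md")`. Note that the diff also round-trips the display name: entries loaded through the recursive URL loader get their `**` suffix re-appended before being stored in `RagFile`, so the user's original input is what shows up later.
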
@@ -294,11 +326,11 @@ impl Rag {
 
         let embeddings_data = EmbeddingsData::new(texts, false);
         let embeddings = self
-            .create_embeddings(embeddings_data, progress_tx.clone())
+            .create_embeddings(embeddings_data, spinner.clone())
             .await?;
 
         self.data.add(rag_files, vector_ids, embeddings);
-        progress(&progress_tx, "Building vector store".into());
+        progress(&spinner, "Building vector store".into());
         self.hnsw = self.data.build_hnsw();
         self.bm25 = self.data.build_bm25();
 
@@ -418,17 +450,17 @@ impl Rag {
     async fn create_embeddings(
         &self,
         data: EmbeddingsData,
-        progress_tx: Option<mpsc::UnboundedSender<String>>,
+        spinner: Option<Spinner>,
     ) -> Result<EmbeddingsOutput> {
         let EmbeddingsData { texts, query } = data;
         let mut output = vec![];
         let batch_chunks = texts.chunks(self.embedding_model.max_batch_size());
         let batch_chunks_len = batch_chunks.len();
-        progress(
-            &progress_tx,
-            format!("Creating embeddings [1/{batch_chunks_len}]"),
-        );
         for (index, texts) in batch_chunks.enumerate() {
+            progress(
+                &spinner,
+                format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1),
+            );
             let chunk_data = EmbeddingsData {
                 texts: texts.to_vec(),
                 query,
@@ -439,10 +471,6 @@ impl Rag {
                 .await
                 .context("Failed to create embedding")?;
             output.extend(chunk_output);
-            progress(
-                &progress_tx,
-                format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1),
-            );
         }
         Ok(output)
     }
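
Taken together, these two hunks fix the progress accounting in `create_embeddings`: previously one message was emitted before the loop (`[1/n]`) and another after each batch, so the first batch was announced twice and the display lagged one batch behind; now a single message is emitted at the top of each iteration, describing the batch in flight. The resulting loop shape, reduced to its skeleton (helper names are mine):

    /// Report `[i+1/n]` once per batch, before that batch is embedded.
    /// `max_batch_size` is the model's batch limit and must be non-zero.
    fn embed_in_batches(texts: &[String], max_batch_size: usize, report: impl Fn(String)) {
        let batch_chunks = texts.chunks(max_batch_size);
        let batch_chunks_len = batch_chunks.len();
        for (index, _batch) in batch_chunks.enumerate() {
            report(format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1));
            // ...call the embeddings API for `_batch` and extend the output here...
        }
    }
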
@@ -510,7 +538,7 @@ pub struct RagFile {
     documents: Vec<RagDocument>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct RagDocument {
     pub page_content: String,
     pub metadata: RagMetadata,
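
Deriving `PartialEq` and `Eq` on `RagDocument` makes split output directly comparable, which is useful for deduplication and for tests. A hypothetical use, assuming `RagMetadata` also implements `Eq` and `Default` (its map-like usage in the splitter code above suggests it does):

    #[test]
    fn identical_documents_compare_equal() {
        let a = RagDocument {
            page_content: "hello".into(),
            metadata: Default::default(),
        };
        assert_eq!(a, a.clone()); // possible only with the new derives
    }
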
@@ -603,15 +631,15 @@ fn set_chunk_overlay(default_value: usize) -> Result<usize> {
 fn add_doc_paths() -> Result<Vec<String>> {
     let text = Text::new("Add document paths:")
         .with_validator(required!("This field is required"))
-        .with_help_message("e.g. file1;dir2/;dir3/**/*.md")
+        .with_help_message("e.g. file;dir/;dir/**/*.md;url;sites/**")
         .prompt()?;
     let paths = text.split(';').map(|v| v.to_string()).collect();
     Ok(paths)
 }
 
-fn progress(spinner_message_tx: &Option<mpsc::UnboundedSender<String>>, message: String) {
-    if let Some(tx) = spinner_message_tx {
-        let _ = tx.send(message);
+fn progress(spinner: &Option<Spinner>, message: String) {
+    if let Some(spinner) = spinner {
+        let _ = spinner.set_message(message);
     }
 }
 
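Finally, the prompt's help text and the `progress` helper document the new contract: inputs are `;`-separated and may now include URLs, and progress messages go through the `Spinner` handle instead of a raw channel sender. For example, a single answer such as

    notes.md;docs/;docs/**/*.md;https://example.com/page;https://example.com/docs/**

adds one file, one directory, a glob, one web page, and a recursively crawled site (the names and URLs here are illustrative).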