refactor: improve RAG (#675)

Branch: pull/676/head
Author: sigoden, 4 months ago, committed by GitHub
Parent: 7c6dac061b
Commit: 2f2b13c891

@@ -34,7 +34,7 @@ document_loaders:
   docx: 'pandoc --to plain $1'          # Load .docx file
   # xlsx: 'ssconvert $1 $2'             # Load .xlsx file
   # html: 'pandoc --to plain $1'        # Load .html file
-  # recursive_url: 'rag-crawler $1 $2'  # Load websites, see https://github.com/sigoden/rag-crawler
+  recursive_url: 'rag-crawler $1 $2'    # Load websites, see https://github.com/sigoden/rag-crawler
 
 # ---- function-calling & agent ----
 # Controls the function calling feature. For setup instructions, visit https://github.com/sigoden/llm-functions
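Note: in these loader commands, $1 is the input path or URL and $2 is an output file the loader writes its result to. A rough sketch of how such a template might be expanded and spawned; the repo's actual run_loader_command is not shown in this diff, so the names and behavior here are illustrative only:

    // Illustrative sketch, not repo code: expand the "$1"/"$2" placeholders
    // in a loader command template, then spawn it as a child process.
    use std::process::Command;

    fn expand_loader_command(template: &str, input: &str, output: &str) -> Vec<String> {
        template
            .split_whitespace()
            .map(|part| match part {
                "$1" => input.to_string(),  // $1: source path or URL
                "$2" => output.to_string(), // $2: file the loader writes to
                other => other.to_string(),
            })
            .collect()
    }

    fn main() {
        let args = expand_loader_command("rag-crawler $1 $2", "https://example.com", "out.json");
        let (program, rest) = args.split_first().expect("non-empty command");
        // Spawns the crawler; with the recursive_url loader its JSON output lands in $2.
        println!("{:?}", Command::new(program).args(rest).status());
    }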

@@ -1446,12 +1446,16 @@ impl Config {
     }
 
     fn setup_document_loaders(&mut self) {
-        [("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")]
-            .into_iter()
-            .for_each(|(k, v)| {
-                let (k, v) = (k.to_string(), v.to_string());
-                self.document_loaders.entry(k).or_insert(v);
-            });
+        [
+            ("pdf", "pdftotext $1 -"),
+            ("docx", "pandoc --to plain $1"),
+            (RECURSIVE_URL_LOADER, "rag-crawler $1 $2"),
+        ]
+        .into_iter()
+        .for_each(|(k, v)| {
+            let (k, v) = (k.to_string(), v.to_string());
+            self.document_loaders.entry(k).or_insert(v);
+        });
     }
 }
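A minimal sketch (my illustration, not repo code) of why entry().or_insert() matters in setup_document_loaders: the built-in defaults, including the newly registered RECURSIVE_URL_LOADER entry, only fill gaps and never overwrite a loader the user already configured:

    use std::collections::HashMap;

    fn main() {
        let mut loaders: HashMap<String, String> = HashMap::new();
        // Suppose the user configured their own pdf loader in the config file.
        loaders.insert("pdf".into(), "my-custom-pdf-tool $1".into());

        for (k, v) in [("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")] {
            loaders.entry(k.to_string()).or_insert(v.to_string());
        }

        assert_eq!(loaders["pdf"], "my-custom-pdf-tool $1"); // user value preserved
        assert_eq!(loaders["docx"], "pandoc --to plain $1"); // default filled in
    }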

@@ -14,7 +14,7 @@ pub async fn load_recrusive_url(
     let extension = RECURSIVE_URL_LOADER;
     let loader_command = loaders
         .get(extension)
-        .with_context(|| format!("RAG document loader '{extension}' not configured"))?;
+        .with_context(|| format!("Document loader '{extension}' not configured"))?;
     let contents = run_loader_command(path, extension, loader_command)?;
     let pages: Vec<WebPage> = serde_json::from_str(&contents).context(r#"The crawler response is invalid. It should follow the JSON format: `[{"path":"...", "text":"..."}]`."#)?;
     let output = pages
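The WebPage type itself is outside this hunk; a minimal serde sketch of a struct that would deserialize the crawler's documented JSON shape, [{"path":"...", "text":"..."}] (field meanings inferred from that format):

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct WebPage {
        path: String, // page URL or path reported by the crawler
        text: String, // extracted plain text of that page
    }

    fn main() -> Result<(), serde_json::Error> {
        let contents = r#"[{"path": "https://example.com/docs", "text": "Hello"}]"#;
        let pages: Vec<WebPage> = serde_json::from_str(contents)?;
        assert_eq!(pages[0].text, "Hello");
        Ok(())
    }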

@@ -4,8 +4,6 @@ pub use self::language::*;
 
 use super::{RagDocument, RagMetadata};
-use std::cmp::Ordering;
-
 pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""];
 
 pub fn get_separators(extension: &str) -> Vec<&'static str> {
@@ -71,15 +69,6 @@ impl RecursiveCharacterTextSplitter {
         self
     }
 
-    #[allow(unused)]
-    pub fn with_length_function<F>(mut self, length_function: F) -> Self
-    where
-        F: Fn(&str) -> usize + Send + Sync + 'static,
-    {
-        self.length_function = Box::new(length_function);
-        self
-    }
-
     pub fn split_documents(
         &self,
         documents: &[RagDocument],
@@ -110,45 +99,33 @@ impl RecursiveCharacterTextSplitter {
         let mut documents = Vec::new();
         for (i, text) in texts.iter().enumerate() {
-            let mut line_counter_index = 1;
-            let mut prev_chunk = None;
-            let mut index_prev_chunk = None;
+            let mut prev_chunk: Option<String> = None;
+            let mut index_prev_chunk = -1;
             for chunk in self.split_text(text) {
                 let mut page_content = chunk_header.clone();
-                let index_chunk = {
-                    let idx = match index_prev_chunk {
-                        Some(v) => v + 1,
-                        None => 0,
-                    };
-                    text[idx..].find(&chunk).map(|i| i + idx).unwrap_or(0)
-                };
-                if prev_chunk.is_none() {
-                    line_counter_index += self.number_of_newlines(text, 0, index_chunk);
+                let index_chunk = if index_prev_chunk < 0 {
+                    text.find(&chunk).map(|i| i as i32).unwrap_or(-1)
                 } else {
-                    let index_end_prev_chunk: usize = index_prev_chunk.unwrap_or_default()
-                        + (self.length_function)(prev_chunk.as_deref().unwrap_or_default());
-
-                    match index_end_prev_chunk.cmp(&index_chunk) {
-                        Ordering::Less => {
-                            line_counter_index +=
-                                self.number_of_newlines(text, index_end_prev_chunk, index_chunk);
-                        }
-                        Ordering::Greater => {
-                            let number =
-                                self.number_of_newlines(text, index_chunk, index_end_prev_chunk);
-                            line_counter_index = line_counter_index.saturating_sub(number);
-                        }
-                        Ordering::Equal => {}
+                    match text[(index_prev_chunk as usize)..].chars().next() {
+                        Some(c) => {
+                            let offset = (index_prev_chunk as usize) + c.len_utf8();
+                            text[offset..]
+                                .find(&chunk)
+                                .map(|i| (i + offset) as i32)
+                                .unwrap_or(-1)
+                        }
+                        None => -1,
                     }
-
+                };
+                if prev_chunk.is_some() {
                     if let Some(chunk_overlap_header) = chunk_overlap_header {
                         page_content += chunk_overlap_header;
                     }
                 }
-                let newlines_count = self.number_of_newlines(&chunk, 0, chunk.len());
                 let metadata = metadatas[i].clone();
                 page_content += &chunk;
                 documents.push(RagDocument {
@@ -156,19 +133,14 @@ impl RecursiveCharacterTextSplitter {
                     metadata,
                 });
-                line_counter_index += newlines_count;
                 prev_chunk = Some(chunk);
-                index_prev_chunk = Some(index_chunk);
+                index_prev_chunk = index_chunk;
             }
         }
         documents
     }
 
-    fn number_of_newlines(&self, text: &str, start: usize, end: usize) -> usize {
-        text[start..end].matches('\n').count()
-    }
-
     pub fn split_text(&self, text: &str) -> Vec<String> {
         let keep_separator = self
             .separators
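Net effect of the splitter changes: all line-counter bookkeeping (line_counter_index, number_of_newlines, the unused with_length_function builder) is removed, and the search for a chunk's position now advances past the previous match by the width of one UTF-8 character rather than a raw + 1, which could land inside a multi-byte character and panic when slicing. A standalone sketch of that boundary-safe search (my illustration, not repo code):

    // Find `needle` strictly after `prev_index`, stepping over exactly one
    // character by its UTF-8 width so the slice always starts on a char boundary.
    fn next_match_from(text: &str, prev_index: usize, needle: &str) -> Option<usize> {
        let c = text[prev_index..].chars().next()?;
        let offset = prev_index + c.len_utf8();
        text[offset..].find(needle).map(|i| i + offset)
    }

    fn main() {
        let text = "héllo héllo"; // 'é' occupies 2 bytes in UTF-8
        let first = text.find("héllo").unwrap(); // byte 0
        // Searching strictly after the first match finds the second occurrence at byte 7.
        assert_eq!(next_match_from(text, first, "héllo"), Some(7));
    }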

@@ -29,23 +29,31 @@ pub async fn fetch(loaders: &HashMap<String, String>, path: &str) -> Result<(Str
         Err(ref err) => bail!("{err}"),
     };
     let mut res = client.get(path).send().await?;
-    let extension = path
-        .rsplit_once('/')
-        .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
-        .unwrap_or(DEFAULT_EXTENSION);
-    let mut extension = extension.to_lowercase();
     let content_type = res
         .headers()
         .get(CONTENT_TYPE)
         .and_then(|v| v.to_str().ok())
         .map(|v| match v.split_once(';') {
-            Some((mime, _)) => mime,
+            Some((mime, _)) => mime.trim(),
             None => v,
-        });
-    if let Some(true) = content_type.map(|v| v.contains("text/html")) {
-        extension = "html".into()
-    }
+        })
+        .unwrap_or_default();
+    let extension = match content_type {
+        "application/pdf" => "pdf",
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx",
+        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx",
+        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => "pptx",
+        "application/vnd.oasis.opendocument.text" => "odt",
+        "application/vnd.oasis.opendocument.spreadsheet" => "ods",
+        "application/vnd.oasis.opendocument.presentation" => "odp",
+        "application/rtf" => "rtf",
+        "text/html" => "html",
+        _ => path
+            .rsplit_once('/')
+            .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
+            .unwrap_or(DEFAULT_EXTENSION),
+    };
+    let extension = extension.to_lowercase();
     let result = match loaders.get(&extension) {
         Some(loader_command) => {
             let save_path = temp_file("-download-", &format!(".{extension}"))
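A standalone sketch (illustrative, with the mapping table abbreviated) of the new detection order in fetch: strip any MIME parameters from the Content-Type, map known MIME types to an extension, and fall back to the URL's file extension only for unrecognized types:

    fn detect_extension(content_type: &str, path: &str) -> String {
        let mime = match content_type.split_once(';') {
            Some((mime, _)) => mime.trim(), // drop parameters like "; charset=utf-8"
            None => content_type,
        };
        let ext = match mime {
            "application/pdf" => "pdf",
            "text/html" => "html",
            // ...remaining mappings as in the diff above...
            _ => path
                .rsplit_once('/')
                .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
                .unwrap_or("txt"), // stand-in for DEFAULT_EXTENSION
        };
        ext.to_lowercase()
    }

    fn main() {
        assert_eq!(detect_extension("text/html; charset=utf-8", "https://x.y/page"), "html");
        assert_eq!(detect_extension("application/octet-stream", "https://x.y/report.PDF"), "pdf");
    }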
