refactor: improve RAG (#675)

pull/676/head
sigoden 4 months ago committed by GitHub
parent 7c6dac061b
commit 2f2b13c891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -34,7 +34,7 @@ document_loaders:
docx: 'pandoc --to plain $1' # Load .docx file
# xlsx: 'ssconvert $1 $2' # Load .xlsx file
# html: 'pandoc --to plain $1' # Load .html file
# recursive_url: 'rag-crawler $1 $2' # Load websites, see https://github.com/sigoden/rag-crawler
recursive_url: 'rag-crawler $1 $2' # Load websites, see https://github.com/sigoden/rag-crawler
# ---- function-calling & agent ----
# Controls the function calling feature. For setup instructions, visit https://github.com/sigoden/llm-functions

@ -1446,7 +1446,11 @@ impl Config {
}
fn setup_document_loaders(&mut self) {
[("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")]
[
("pdf", "pdftotext $1 -"),
("docx", "pandoc --to plain $1"),
(RECURSIVE_URL_LOADER, "rag-crawler $1 $2"),
]
.into_iter()
.for_each(|(k, v)| {
let (k, v) = (k.to_string(), v.to_string());

@ -14,7 +14,7 @@ pub async fn load_recrusive_url(
let extension = RECURSIVE_URL_LOADER;
let loader_command = loaders
.get(extension)
.with_context(|| format!("RAG document loader '{extension}' not configured"))?;
.with_context(|| format!("Document loader '{extension}' not configured"))?;
let contents = run_loader_command(path, extension, loader_command)?;
let pages: Vec<WebPage> = serde_json::from_str(&contents).context(r#"The crawler response is invalid. It should follow the JSON format: `[{"path":"...", "text":"..."}]`."#)?;
let output = pages

@ -4,8 +4,6 @@ pub use self::language::*;
use super::{RagDocument, RagMetadata};
use std::cmp::Ordering;
pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""];
pub fn get_separators(extension: &str) -> Vec<&'static str> {
@ -71,15 +69,6 @@ impl RecursiveCharacterTextSplitter {
self
}
/// Builder-style setter for the splitter's length function.
///
/// Takes the splitter by value, boxes `length_function`, stores it in
/// `self.length_function`, and returns the updated splitter so calls can be chained.
///
/// `length_function` maps a `&str` to a `usize` length; it must be
/// `Send + Sync + 'static` because it is stored as a boxed trait object.
// NOTE(review): `allow(unused)` suggests there is currently no caller — confirm.
#[allow(unused)]
pub fn with_length_function<F>(mut self, length_function: F) -> Self
where
F: Fn(&str) -> usize + Send + Sync + 'static,
{
self.length_function = Box::new(length_function);
self
}
pub fn split_documents(
&self,
documents: &[RagDocument],
@ -110,45 +99,33 @@ impl RecursiveCharacterTextSplitter {
let mut documents = Vec::new();
for (i, text) in texts.iter().enumerate() {
let mut line_counter_index = 1;
let mut prev_chunk = None;
let mut index_prev_chunk = None;
let mut prev_chunk: Option<String> = None;
let mut index_prev_chunk = -1;
for chunk in self.split_text(text) {
let mut page_content = chunk_header.clone();
let index_chunk = {
let idx = match index_prev_chunk {
Some(v) => v + 1,
None => 0,
};
text[idx..].find(&chunk).map(|i| i + idx).unwrap_or(0)
};
if prev_chunk.is_none() {
line_counter_index += self.number_of_newlines(text, 0, index_chunk);
let index_chunk = if index_prev_chunk < 0 {
text.find(&chunk).map(|i| i as i32).unwrap_or(-1)
} else {
let index_end_prev_chunk: usize = index_prev_chunk.unwrap_or_default()
+ (self.length_function)(prev_chunk.as_deref().unwrap_or_default());
match index_end_prev_chunk.cmp(&index_chunk) {
Ordering::Less => {
line_counter_index +=
self.number_of_newlines(text, index_end_prev_chunk, index_chunk);
}
Ordering::Greater => {
let number =
self.number_of_newlines(text, index_chunk, index_end_prev_chunk);
line_counter_index = line_counter_index.saturating_sub(number);
match text[(index_prev_chunk as usize)..].chars().next() {
Some(c) => {
let offset = (index_prev_chunk as usize) + c.len_utf8();
text[offset..]
.find(&chunk)
.map(|i| (i + offset) as i32)
.unwrap_or(-1)
}
Ordering::Equal => {}
None => -1,
}
};
if prev_chunk.is_some() {
if let Some(chunk_overlap_header) = chunk_overlap_header {
page_content += chunk_overlap_header;
}
}
let newlines_count = self.number_of_newlines(&chunk, 0, chunk.len());
let metadata = metadatas[i].clone();
page_content += &chunk;
documents.push(RagDocument {
@ -156,19 +133,14 @@ impl RecursiveCharacterTextSplitter {
metadata,
});
line_counter_index += newlines_count;
prev_chunk = Some(chunk);
index_prev_chunk = Some(index_chunk);
index_prev_chunk = index_chunk;
}
}
documents
}
/// Counts the `'\n'` characters in the byte range `start..end` of `text`.
///
/// `start` and `end` are byte offsets into `text`. The slice expression panics
/// if `start > end`, if `end > text.len()`, or if either offset is not on a
/// UTF-8 character boundary, so callers must pass valid offsets.
/// `self` is not read by this method.
fn number_of_newlines(&self, text: &str, start: usize, end: usize) -> usize {
text[start..end].matches('\n').count()
}
pub fn split_text(&self, text: &str) -> Vec<String> {
let keep_separator = self
.separators

@ -29,23 +29,31 @@ pub async fn fetch(loaders: &HashMap<String, String>, path: &str) -> Result<(Str
Err(ref err) => bail!("{err}"),
};
let mut res = client.get(path).send().await?;
let extension = path
.rsplit_once('/')
.and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
.unwrap_or(DEFAULT_EXTENSION);
let mut extension = extension.to_lowercase();
let content_type = res
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|v| match v.split_once(';') {
Some((mime, _)) => mime,
Some((mime, _)) => mime.trim(),
None => v,
});
if let Some(true) = content_type.map(|v| v.contains("text/html")) {
extension = "html".into()
}
})
.unwrap_or_default();
let extension = match content_type {
"application/pdf" => "pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation" => "pptx",
"application/vnd.oasis.opendocument.text" => "odt",
"application/vnd.oasis.opendocument.spreadsheet" => "ods",
"application/vnd.oasis.opendocument.presentation" => "odp",
"application/rtf" => "rtf",
"text/html" => "html",
_ => path
.rsplit_once('/')
.and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
.unwrap_or(DEFAULT_EXTENSION),
};
let extension = extension.to_lowercase();
let result = match loaders.get(&extension) {
Some(loader_command) => {
let save_path = temp_file("-download-", &format!(".{extension}"))

Loading…
Cancel
Save