feat: rag splitter supports languages (#593)

pull/594/head
sigoden 4 weeks ago committed by GitHub
parent 492b006db7
commit 64982b4510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -186,8 +186,12 @@ impl Rag {
.extension() .extension()
.map(|v| v.to_string_lossy().to_lowercase()) .map(|v| v.to_string_lossy().to_lowercase())
.unwrap_or_default(); .unwrap_or_default();
let separator = autodetect_separator(&extension); let separator = detect_separators(&extension);
let splitter = Splitter::new(self.data.chunk_size, self.data.chunk_overlap, separator); let splitter = RecursiveCharacterTextSplitter::new(
self.data.chunk_size,
self.data.chunk_overlap,
&separator,
);
let documents = load(&path, &extension) let documents = load(&path, &extension)
.with_context(|| format!("Failed to load file at '{path}'"))?; .with_context(|| format!("Failed to load file at '{path}'"))?;
let documents = let documents =
@ -207,9 +211,9 @@ impl Rag {
let mut vector_ids = vec![]; let mut vector_ids = vec![];
let mut texts = vec![]; let mut texts = vec![];
for (file_index, file) in rag_files.iter().enumerate() { for (file_index, file) in rag_files.iter().enumerate() {
for (document_index, doc) in file.documents.iter().enumerate() { for (document_index, document) in file.documents.iter().enumerate() {
vector_ids.push(combine_vector_id(file_index, document_index)); vector_ids.push(combine_vector_id(file_index, document_index));
texts.push(doc.page_content.clone()) texts.push(document_text(&file.path, document))
} }
} }
@ -226,7 +230,7 @@ impl Rag {
} }
async fn search_impl(&self, text: &str, top_k: usize) -> Result<Vec<String>> { async fn search_impl(&self, text: &str, top_k: usize) -> Result<Vec<String>> {
let splitter = Splitter::new( let splitter = RecursiveCharacterTextSplitter::new(
self.data.chunk_size, self.data.chunk_size,
self.data.chunk_overlap, self.data.chunk_overlap,
&DEFAULT_SEPARATES, &DEFAULT_SEPARATES,
@ -245,10 +249,9 @@ impl Rag {
return None; return None;
} }
let (file_index, document_index) = split_vector_id(v.d_id); let (file_index, document_index) = split_vector_id(v.d_id);
let text = self.data.files[file_index].documents[document_index] let file = self.data.files.get(file_index)?;
.page_content let document = file.documents.get(document_index)?;
.clone(); Some(document_text(&file.path, document))
Some(text)
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
@ -378,6 +381,14 @@ pub fn split_vector_id(value: VectorID) -> (usize, usize) {
(high, low) (high, low)
} }
fn document_text(file_path: &str, document: &RagDocument) -> String {
format!(
"file_path: {}\n\n{}",
shell_words::quote(file_path),
document.page_content
)
}
fn retrieve_embedding_model(config: &Config, model_id: &str) -> Result<Model> { fn retrieve_embedding_model(config: &Config, model_id: &str) -> Result<Model> {
let model = Model::find(&list_embedding_models(config), model_id) let model = Model::find(&list_embedding_models(config), model_id)
.ok_or_else(|| anyhow!("No embedding model '{model_id}'"))?; .ok_or_else(|| anyhow!("No embedding model '{model_id}'"))?;

@ -0,0 +1,235 @@
#[derive(PartialEq, Eq, Hash)]
pub enum Language {
Cpp,
Go,
Java,
Js,
Php,
Proto,
Python,
Rst,
Ruby,
Rust,
Scala,
Swift,
Markdown,
Latex,
Html,
Sol,
}
impl Language {
pub fn separators(&self) -> Vec<&str> {
match self {
Language::Cpp => vec![
"\nclass ",
"\nvoid ",
"\nint ",
"\nfloat ",
"\ndouble ",
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Go => vec![
"\nfunc ",
"\nvar ",
"\nconst ",
"\ntype ",
"\nif ",
"\nfor ",
"\nswitch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Java => vec![
"\nclass ",
"\npublic ",
"\nprotected ",
"\nprivate ",
"\nstatic ",
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Js => vec![
"\nfunction ",
"\nconst ",
"\nlet ",
"\nvar ",
"\nclass ",
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\ndefault ",
"\n\n",
"\n",
" ",
"",
],
Language::Php => vec![
"\nfunction ",
"\nclass ",
"\nif ",
"\nforeach ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Proto => vec![
"\nmessage ",
"\nservice ",
"\nenum ",
"\noption ",
"\nimport ",
"\nsyntax ",
"\n\n",
"\n",
" ",
"",
],
Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""],
Language::Rst => vec![
"\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "",
],
Language::Ruby => vec![
"\ndef ",
"\nclass ",
"\nif ",
"\nunless ",
"\nwhile ",
"\nfor ",
"\ndo ",
"\nbegin ",
"\nrescue ",
"\n\n",
"\n",
" ",
"",
],
Language::Rust => vec![
"\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ",
"\nmatch ", "\nconst ", "\n\n", "\n", " ", "",
],
Language::Scala => vec![
"\nclass ",
"\nobject ",
"\ndef ",
"\nval ",
"\nvar ",
"\nif ",
"\nfor ",
"\nwhile ",
"\nmatch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Swift => vec![
"\nfunc ",
"\nclass ",
"\nstruct ",
"\nenum ",
"\nif ",
"\nfor ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
"\n\n",
"\n",
" ",
"",
],
Language::Markdown => vec![
"\n## ",
"\n### ",
"\n#### ",
"\n##### ",
"\n###### ",
"```\n\n",
"\n\n***\n\n",
"\n\n---\n\n",
"\n\n___\n\n",
"\n\n",
"\n",
" ",
"",
],
Language::Latex => vec![
"\n\\chapter{",
"\n\\section{",
"\n\\subsection{",
"\n\\subsubsection{",
"\n\\begin{enumerate}",
"\n\\begin{itemize}",
"\n\\begin{description}",
"\n\\begin{list}",
"\n\\begin{quote}",
"\n\\begin{quotation}",
"\n\\begin{verse}",
"\n\\begin{verbatim}",
"\n\\begin{align}",
"$$",
"$",
"\n\n",
"\n",
" ",
"",
],
Language::Html => vec![
"<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>",
"<h6>", "<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>",
"<footer>", "<nav>", "<head>", "<style>", "<script>", "<meta>", "<title>", " ", "",
],
Language::Sol => vec![
"\npragma ",
"\nusing ",
"\ncontract ",
"\ninterface ",
"\nlibrary ",
"\nconstructor ",
"\ntype ",
"\nfunction ",
"\nevent ",
"\nmodifier ",
"\nerror ",
"\nstruct ",
"\nenum ",
"\nif ",
"\nfor ",
"\nwhile ",
"\ndo while ",
"\nassembly ",
"\n\n",
"\n",
" ",
"",
],
}
}
}

@ -1,82 +1,43 @@
mod language;
pub use self::language::*;
use super::{RagDocument, RagMetadata}; use super::{RagDocument, RagMetadata};
use std::cmp::Ordering; use std::cmp::Ordering;
pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""]; pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""];
pub const HTML_SEPARATES: [&str; 28] = [
// First, try to split along HTML tags pub fn detect_separators(extension: &str) -> Vec<&'static str> {
"<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>", "<h6>",
"<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>", "<footer>", "<nav>",
// Head
"<head>", "<style>", "<script>", "<meta>", "<title>", // Normal type of lines
" ", "",
];
pub const MARKDOWN_SEPARATES: [&str; 13] = [
// First, try to split along Markdown headings (starting with level 2)
"\n## ",
"\n### ",
"\n#### ",
"\n##### ",
"\n###### ",
// Note the alternative syntax for headings (below) is not handled here
// Heading level 2
// ---------------
// End of code block
"```\n\n",
// Horizontal lines
"\n\n***\n\n",
"\n\n---\n\n",
"\n\n___\n\n",
// Note that this splitter doesn't handle horizontal lines defined
// by *three or more* of ***, ---, or ___, but this is not handled
"\n\n",
"\n",
" ",
"",
];
pub const LATEX_SEPARATES: [&str; 19] = [
// First, try to split along Latex sections
"\n\\chapter{",
"\n\\section{",
"\n\\subsection{",
"\n\\subsubsection{",
// Now split by environments
"\n\\begin{enumerate}",
"\n\\begin{itemize}",
"\n\\begin{description}",
"\n\\begin{list}",
"\n\\begin{quote}",
"\n\\begin{quotation}",
"\n\\begin{verse}",
"\n\\begin{verbatim}",
// Now split by math environments
"\n\\begin{align}",
"$$",
"$",
// Now split by the normal type of lines
"\n\n",
"\n",
" ",
"",
];
pub fn autodetect_separator(extension: &str) -> &[&'static str] {
match extension { match extension {
"md" | "mkd" => &MARKDOWN_SEPARATES, "c" | "cc" | "cpp" => Language::Cpp.separators(),
"htm" | "html" => &HTML_SEPARATES, "go" => Language::Go.separators(),
"tex" => &LATEX_SEPARATES, "java" => Language::Java.separators(),
_ => &DEFAULT_SEPARATES, "js" | "mjs" | "cjs" => Language::Js.separators(),
"php" => Language::Php.separators(),
"proto" => Language::Proto.separators(),
"py" => Language::Python.separators(),
"rst" => Language::Rst.separators(),
"rb" => Language::Ruby.separators(),
"rs" => Language::Rust.separators(),
"scala" => Language::Scala.separators(),
"swift" => Language::Swift.separators(),
"md" | "mkd" => Language::Markdown.separators(),
"tex" => Language::Latex.separators(),
"htm" | "html" => Language::Html.separators(),
"sol" => Language::Sol.separators(),
_ => DEFAULT_SEPARATES.to_vec(),
} }
} }
pub struct Splitter { pub struct RecursiveCharacterTextSplitter {
pub chunk_size: usize, pub chunk_size: usize,
pub chunk_overlap: usize, pub chunk_overlap: usize,
pub separators: Vec<String>, pub separators: Vec<String>,
pub length_function: Box<dyn Fn(&str) -> usize + Send + Sync>, pub length_function: Box<dyn Fn(&str) -> usize + Send + Sync>,
} }
impl Default for Splitter { impl Default for RecursiveCharacterTextSplitter {
fn default() -> Self { fn default() -> Self {
Self { Self {
chunk_size: 1000, chunk_size: 1000,
@ -87,8 +48,7 @@ impl Default for Splitter {
} }
} }
// Builder pattern for Options struct impl RecursiveCharacterTextSplitter {
impl Splitter {
pub fn new(chunk_size: usize, chunk_overlap: usize, separators: &[&str]) -> Self { pub fn new(chunk_size: usize, chunk_overlap: usize, separators: &[&str]) -> Self {
Self::default() Self::default()
.with_chunk_size(chunk_size) .with_chunk_size(chunk_size)
@ -406,7 +366,7 @@ mod tests {
} }
#[test] #[test]
fn test_split_text() { fn test_split_text() {
let splitter = Splitter { let splitter = RecursiveCharacterTextSplitter {
chunk_size: 7, chunk_size: 7,
chunk_overlap: 3, chunk_overlap: 3,
separators: vec![" ".into()], separators: vec![" ".into()],
@ -418,7 +378,7 @@ mod tests {
#[test] #[test]
fn test_create_document() { fn test_create_document() {
let splitter = Splitter::new(3, 0, &[" "]); let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default(); let chunk_header_options = SplitterChunkHeaderOptions::default();
let mut metadata1 = IndexMap::new(); let mut metadata1 = IndexMap::new();
metadata1.insert("source".into(), "1".into()); metadata1.insert("source".into(), "1".into());
@ -451,7 +411,7 @@ mod tests {
#[test] #[test]
fn test_chunk_header() { fn test_chunk_header() {
let splitter = Splitter::new(3, 0, &[" "]); let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default() let chunk_header_options = SplitterChunkHeaderOptions::default()
.with_chunk_header("SOURCE NAME: testing\n-----\n") .with_chunk_header("SOURCE NAME: testing\n-----\n")
.with_append_chunk_overlap_header(true); .with_append_chunk_overlap_header(true);
@ -498,7 +458,8 @@ pip install langchain
``` ```
As an open source project in a rapidly developing field, we are extremely open to contributions."#; As an open source project in a rapidly developing field, we are extremely open to contributions."#;
let splitter = Splitter::new(100, 0, &MARKDOWN_SEPARATES); let splitter =
RecursiveCharacterTextSplitter::new(100, 0, &Language::Markdown.separators());
let output = splitter.split_text(text); let output = splitter.split_text(text);
let expected_output = vec![ let expected_output = vec![
"# 🦜️🔗 LangChain\n\n⚡ Building applications with LLMs through composability ⚡", "# 🦜️🔗 LangChain\n\n⚡ Building applications with LLMs through composability ⚡",
@ -534,7 +495,7 @@ As an open source project in a rapidly developing field, we are extremely open t
</div> </div>
</body> </body>
</html>"#; </html>"#;
let splitter = Splitter::new(175, 20, &HTML_SEPARATES); let splitter = RecursiveCharacterTextSplitter::new(175, 20, &Language::Html.separators());
let output = splitter.split_text(text); let output = splitter.split_text(text);
let expected_output = vec![ let expected_output = vec![
"<!DOCTYPE html>\n<html>", "<!DOCTYPE html>\n<html>",
Loading…
Cancel
Save