feat: rag splitter supports languages (#593)

pull/594/head
sigoden 3 weeks ago committed by GitHub
parent 492b006db7
commit 64982b4510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -186,8 +186,12 @@ impl Rag {
.extension()
.map(|v| v.to_string_lossy().to_lowercase())
.unwrap_or_default();
let separator = autodetect_separator(&extension);
let splitter = Splitter::new(self.data.chunk_size, self.data.chunk_overlap, separator);
let separator = detect_separators(&extension);
let splitter = RecursiveCharacterTextSplitter::new(
self.data.chunk_size,
self.data.chunk_overlap,
&separator,
);
let documents = load(&path, &extension)
.with_context(|| format!("Failed to load file at '{path}'"))?;
let documents =
@ -207,9 +211,9 @@ impl Rag {
let mut vector_ids = vec![];
let mut texts = vec![];
for (file_index, file) in rag_files.iter().enumerate() {
for (document_index, doc) in file.documents.iter().enumerate() {
for (document_index, document) in file.documents.iter().enumerate() {
vector_ids.push(combine_vector_id(file_index, document_index));
texts.push(doc.page_content.clone())
texts.push(document_text(&file.path, document))
}
}
@ -226,7 +230,7 @@ impl Rag {
}
async fn search_impl(&self, text: &str, top_k: usize) -> Result<Vec<String>> {
let splitter = Splitter::new(
let splitter = RecursiveCharacterTextSplitter::new(
self.data.chunk_size,
self.data.chunk_overlap,
&DEFAULT_SEPARATES,
@ -245,10 +249,9 @@ impl Rag {
return None;
}
let (file_index, document_index) = split_vector_id(v.d_id);
let text = self.data.files[file_index].documents[document_index]
.page_content
.clone();
Some(text)
let file = self.data.files.get(file_index)?;
let document = file.documents.get(document_index)?;
Some(document_text(&file.path, document))
})
.collect::<Vec<_>>()
})
@ -378,6 +381,14 @@ pub fn split_vector_id(value: VectorID) -> (usize, usize) {
(high, low)
}
/// Renders a document chunk together with the path of the file it came from.
///
/// The output starts with a `file_path: ` line carrying the shell-quoted
/// `file_path`, followed by a blank line and then the chunk's `page_content`.
fn document_text(file_path: &str, document: &RagDocument) -> String {
    let quoted_path = shell_words::quote(file_path);
    let body = &document.page_content;
    format!("file_path: {quoted_path}\n\n{body}")
}
fn retrieve_embedding_model(config: &Config, model_id: &str) -> Result<Model> {
let model = Model::find(&list_embedding_models(config), model_id)
.ok_or_else(|| anyhow!("No embedding model '{model_id}'"))?;

@ -0,0 +1,235 @@
/// Source/markup languages for which a dedicated separator list is available.
///
/// Each variant maps to an ordered list of separators (see
/// [`Language::separators`]) that starts with language-specific constructs and
/// ends with the generic `"\n\n"`/`"\n"`/`" "`/`""` style fallbacks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    /// C / C++.
    Cpp,
    /// Go.
    Go,
    /// Java.
    Java,
    /// JavaScript.
    Js,
    /// PHP.
    Php,
    /// Protocol Buffers definitions.
    Proto,
    /// Python.
    Python,
    /// reStructuredText.
    Rst,
    /// Ruby.
    Ruby,
    /// Rust.
    Rust,
    /// Scala.
    Scala,
    /// Swift.
    Swift,
    /// Markdown.
    Markdown,
    /// LaTeX.
    Latex,
    /// HTML.
    Html,
    /// Solidity.
    Sol,
}

impl Language {
    /// Returns the separators for this language, most specific first.
    ///
    /// Every entry is a string literal, so the returned slices are `'static`
    /// and do not borrow from `self`; callers may keep them after the
    /// `Language` value is gone. Each list contains every separator at most
    /// once and always ends with the empty string.
    pub fn separators(&self) -> Vec<&'static str> {
        match self {
            Language::Cpp => vec![
                "\nclass ",
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Go => vec![
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Java => vec![
                "\nclass ",
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Js => vec![
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Php => vec![
                "\nfunction ",
                "\nclass ",
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Proto => vec![
                "\nmessage ",
                "\nservice ",
                "\nenum ",
                "\noption ",
                "\nimport ",
                "\nsyntax ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""],
            Language::Rst => vec![
                "\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "",
            ],
            Language::Ruby => vec![
                "\ndef ",
                "\nclass ",
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            // `"\nconst "` used to be listed a second time after `"\nmatch "`;
            // the repeated entry added nothing to the list and has been dropped.
            Language::Rust => vec![
                "\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ",
                "\nmatch ", "\n\n", "\n", " ", "",
            ],
            Language::Scala => vec![
                "\nclass ",
                "\nobject ",
                "\ndef ",
                "\nval ",
                "\nvar ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Swift => vec![
                "\nfunc ",
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Markdown => vec![
                // Headings, starting at level 2
                "\n## ",
                "\n### ",
                "\n#### ",
                "\n##### ",
                "\n###### ",
                // End of a fenced code block
                "```\n\n",
                // Horizontal rules
                "\n\n***\n\n",
                "\n\n---\n\n",
                "\n\n___\n\n",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Latex => vec![
                // Sectioning commands
                "\n\\chapter{",
                "\n\\section{",
                "\n\\subsection{",
                "\n\\subsubsection{",
                // Environments
                "\n\\begin{enumerate}",
                "\n\\begin{itemize}",
                "\n\\begin{description}",
                "\n\\begin{list}",
                "\n\\begin{quote}",
                "\n\\begin{quotation}",
                "\n\\begin{verse}",
                "\n\\begin{verbatim}",
                // Math
                "\n\\begin{align}",
                "$$",
                "$",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Html => vec![
                "<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>",
                "<h6>", "<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>",
                "<footer>", "<nav>", "<head>", "<style>", "<script>", "<meta>", "<title>", " ", "",
            ],
            Language::Sol => vec![
                "\npragma ",
                "\nusing ",
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
        }
    }
}

@ -1,82 +1,43 @@
mod language;
pub use self::language::*;
use super::{RagDocument, RagMetadata};
use std::cmp::Ordering;
/// Generic separators with no language-specific entries: blank line, newline,
/// space, and finally the empty string.
pub const DEFAULT_SEPARATES: [&str; 4] = ["\n\n", "\n", " ", ""];
/// Separators for HTML: opening tags first, then a plain space and the empty
/// string.
pub const HTML_SEPARATES: [&str; 28] = [
// First, try to split along HTML tags
"<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>", "<h6>",
"<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>", "<footer>", "<nav>",
// Tags that appear in the document head
"<head>", "<style>", "<script>", "<meta>", "<title>", // the next line holds the generic fallbacks
" ", "",
];
/// Separators for Markdown: ATX headings (level 2 and deeper), the end of a
/// fenced code block, horizontal rules, then the generic fallbacks.
pub const MARKDOWN_SEPARATES: [&str; 13] = [
// First, try to split along Markdown headings (starting with level 2)
"\n## ",
"\n### ",
"\n#### ",
"\n##### ",
"\n###### ",
// Note: the alternative (underlined) heading syntax shown below is not handled here:
//   Heading level 2
//   ---------------
// End of code block
"```\n\n",
// Horizontal lines
"\n\n***\n\n",
"\n\n---\n\n",
"\n\n___\n\n",
// Note: only rules made of exactly three `*`, `-`, or `_` characters are
// matched; longer rules (four or more characters) are not handled.
"\n\n",
"\n",
" ",
"",
];
/// Separators for LaTeX: sectioning commands, then `\begin{...}` environments,
/// then math delimiters, then the generic fallbacks.
pub const LATEX_SEPARATES: [&str; 19] = [
// First, try to split along Latex sections
"\n\\chapter{",
"\n\\section{",
"\n\\subsection{",
"\n\\subsubsection{",
// Now split by environments
"\n\\begin{enumerate}",
"\n\\begin{itemize}",
"\n\\begin{description}",
"\n\\begin{list}",
"\n\\begin{quote}",
"\n\\begin{quotation}",
"\n\\begin{verse}",
"\n\\begin{verbatim}",
// Now split by math environments
"\n\\begin{align}",
"$$",
"$",
// Now split by the normal type of lines
"\n\n",
"\n",
" ",
"",
];
pub fn autodetect_separator(extension: &str) -> &[&'static str] {
pub fn detect_separators(extension: &str) -> Vec<&'static str> {
match extension {
"md" | "mkd" => &MARKDOWN_SEPARATES,
"htm" | "html" => &HTML_SEPARATES,
"tex" => &LATEX_SEPARATES,
_ => &DEFAULT_SEPARATES,
"c" | "cc" | "cpp" => Language::Cpp.separators(),
"go" => Language::Go.separators(),
"java" => Language::Java.separators(),
"js" | "mjs" | "cjs" => Language::Js.separators(),
"php" => Language::Php.separators(),
"proto" => Language::Proto.separators(),
"py" => Language::Python.separators(),
"rst" => Language::Rst.separators(),
"rb" => Language::Ruby.separators(),
"rs" => Language::Rust.separators(),
"scala" => Language::Scala.separators(),
"swift" => Language::Swift.separators(),
"md" | "mkd" => Language::Markdown.separators(),
"tex" => Language::Latex.separators(),
"htm" | "html" => Language::Html.separators(),
"sol" => Language::Sol.separators(),
_ => DEFAULT_SEPARATES.to_vec(),
}
}
pub struct Splitter {
pub struct RecursiveCharacterTextSplitter {
pub chunk_size: usize,
pub chunk_overlap: usize,
pub separators: Vec<String>,
pub length_function: Box<dyn Fn(&str) -> usize + Send + Sync>,
}
impl Default for Splitter {
impl Default for RecursiveCharacterTextSplitter {
fn default() -> Self {
Self {
chunk_size: 1000,
@ -87,8 +48,7 @@ impl Default for Splitter {
}
}
// Builder pattern for Options struct
impl Splitter {
impl RecursiveCharacterTextSplitter {
pub fn new(chunk_size: usize, chunk_overlap: usize, separators: &[&str]) -> Self {
Self::default()
.with_chunk_size(chunk_size)
@ -406,7 +366,7 @@ mod tests {
}
#[test]
fn test_split_text() {
let splitter = Splitter {
let splitter = RecursiveCharacterTextSplitter {
chunk_size: 7,
chunk_overlap: 3,
separators: vec![" ".into()],
@ -418,7 +378,7 @@ mod tests {
#[test]
fn test_create_document() {
let splitter = Splitter::new(3, 0, &[" "]);
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default();
let mut metadata1 = IndexMap::new();
metadata1.insert("source".into(), "1".into());
@ -451,7 +411,7 @@ mod tests {
#[test]
fn test_chunk_header() {
let splitter = Splitter::new(3, 0, &[" "]);
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default()
.with_chunk_header("SOURCE NAME: testing\n-----\n")
.with_append_chunk_overlap_header(true);
@ -498,7 +458,8 @@ pip install langchain
```
As an open source project in a rapidly developing field, we are extremely open to contributions."#;
let splitter = Splitter::new(100, 0, &MARKDOWN_SEPARATES);
let splitter =
RecursiveCharacterTextSplitter::new(100, 0, &Language::Markdown.separators());
let output = splitter.split_text(text);
let expected_output = vec![
"# 🦜️🔗 LangChain\n\n⚡ Building applications with LLMs through composability ⚡",
@ -534,7 +495,7 @@ As an open source project in a rapidly developing field, we are extremely open t
</div>
</body>
</html>"#;
let splitter = Splitter::new(175, 20, &HTML_SEPARATES);
let splitter = RecursiveCharacterTextSplitter::new(175, 20, &Language::Html.separators());
let output = splitter.split_text(text);
let expected_output = vec![
"<!DOCTYPE html>\n<html>",
Loading…
Cancel
Save