From 64982b4510e38153885bfd0a78c250110b3e03c5 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 12 Jun 2024 19:17:40 +0800 Subject: [PATCH] feat: rag splitter supports languages (#593) --- src/rag/mod.rs | 29 ++- src/rag/splitter/language.rs | 235 +++++++++++++++++++++++ src/rag/{splitter.rs => splitter/mod.rs} | 103 +++------- 3 files changed, 287 insertions(+), 80 deletions(-) create mode 100644 src/rag/splitter/language.rs rename src/rag/{splitter.rs => splitter/mod.rs} (88%) diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 011b574..f88e80e 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -186,8 +186,12 @@ impl Rag { .extension() .map(|v| v.to_string_lossy().to_lowercase()) .unwrap_or_default(); - let separator = autodetect_separator(&extension); - let splitter = Splitter::new(self.data.chunk_size, self.data.chunk_overlap, separator); + let separator = detect_separators(&extension); + let splitter = RecursiveCharacterTextSplitter::new( + self.data.chunk_size, + self.data.chunk_overlap, + &separator, + ); let documents = load(&path, &extension) .with_context(|| format!("Failed to load file at '{path}'"))?; let documents = @@ -207,9 +211,9 @@ impl Rag { let mut vector_ids = vec![]; let mut texts = vec![]; for (file_index, file) in rag_files.iter().enumerate() { - for (document_index, doc) in file.documents.iter().enumerate() { + for (document_index, document) in file.documents.iter().enumerate() { vector_ids.push(combine_vector_id(file_index, document_index)); - texts.push(doc.page_content.clone()) + texts.push(document_text(&file.path, document)) } } @@ -226,7 +230,7 @@ impl Rag { } async fn search_impl(&self, text: &str, top_k: usize) -> Result> { - let splitter = Splitter::new( + let splitter = RecursiveCharacterTextSplitter::new( self.data.chunk_size, self.data.chunk_overlap, &DEFAULT_SEPARATES, @@ -245,10 +249,9 @@ impl Rag { return None; } let (file_index, document_index) = split_vector_id(v.d_id); - let text = self.data.files[file_index].documents[document_index] - .page_content - .clone(); - Some(text) + let file = self.data.files.get(file_index)?; + let document = file.documents.get(document_index)?; + Some(document_text(&file.path, document)) }) .collect::>() }) @@ -378,6 +381,14 @@ pub fn split_vector_id(value: VectorID) -> (usize, usize) { (high, low) } +fn document_text(file_path: &str, document: &RagDocument) -> String { + format!( + "file_path: {}\n\n{}", + shell_words::quote(file_path), + document.page_content + ) +} + fn retrieve_embedding_model(config: &Config, model_id: &str) -> Result { let model = Model::find(&list_embedding_models(config), model_id) .ok_or_else(|| anyhow!("No embedding model '{model_id}'"))?; diff --git a/src/rag/splitter/language.rs b/src/rag/splitter/language.rs new file mode 100644 index 0000000..20722cf --- /dev/null +++ b/src/rag/splitter/language.rs @@ -0,0 +1,235 @@ +#[derive(PartialEq, Eq, Hash)] +pub enum Language { + Cpp, + Go, + Java, + Js, + Php, + Proto, + Python, + Rst, + Ruby, + Rust, + Scala, + Swift, + Markdown, + Latex, + Html, + Sol, +} + +impl Language { + pub fn separators(&self) -> Vec<&str> { + match self { + Language::Cpp => vec![ + "\nclass ", + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Go => vec![ + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Java => vec![ + "\nclass ", + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Js => vec![ + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + "\n\n", + "\n", + " ", + "", + ], + Language::Php => vec![ + "\nfunction ", + "\nclass ", + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Proto => vec![ + "\nmessage ", + "\nservice ", + "\nenum ", + "\noption ", + "\nimport ", + "\nsyntax ", + "\n\n", + "\n", + " ", + "", + ], + Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""], + Language::Rst => vec![ + "\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "", + ], + Language::Ruby => vec![ + "\ndef ", + "\nclass ", + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + "\n\n", + "\n", + " ", + "", + ], + Language::Rust => vec![ + "\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ", + "\nmatch ", "\nconst ", "\n\n", "\n", " ", "", + ], + Language::Scala => vec![ + "\nclass ", + "\nobject ", + "\ndef ", + "\nval ", + "\nvar ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Swift => vec![ + "\nfunc ", + "\nclass ", + "\nstruct ", + "\nenum ", + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Markdown => vec![ + "\n## ", + "\n### ", + "\n#### ", + "\n##### ", + "\n###### ", + "```\n\n", + "\n\n***\n\n", + "\n\n---\n\n", + "\n\n___\n\n", + "\n\n", + "\n", + " ", + "", + ], + Language::Latex => vec![ + "\n\\chapter{", + "\n\\section{", + "\n\\subsection{", + "\n\\subsubsection{", + "\n\\begin{enumerate}", + "\n\\begin{itemize}", + "\n\\begin{description}", + "\n\\begin{list}", + "\n\\begin{quote}", + "\n\\begin{quotation}", + "\n\\begin{verse}", + "\n\\begin{verbatim}", + "\n\\begin{align}", + "$$", + "$", + "\n\n", + "\n", + " ", + "", + ], + Language::Html => vec![ + "", "
", "

", "
", "

  • ", "

    ", "

    ", "

    ", "

    ", "

    ", + "
    ", "", "", "", "
    ", "", "
      ", "
        ", "
        ", + "