feat: rag hybrid search (#618)

pull/620/head
sigoden 3 months ago committed by GitHub
parent 3b3d39cef0
commit 2eab71a641
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

1
Cargo.lock generated

@ -77,6 +77,7 @@ dependencies = [
"path-absolutize",
"pretty_assertions",
"rand",
"rayon",
"reedline",
"reqwest",
"reqwest-eventsource",

@ -61,6 +61,7 @@ bitflags = "2.5.0"
path-absolutize = "3.1.1"
hnsw_rs = "0.3.0"
which = "6.0.1"
rayon = "1.10.0"
[dependencies.reqwest]
version = "0.12.0"

@ -362,7 +362,6 @@ The author mainly focused on writing and programming growing up ...
.set temperature 1.2
.set top_p 0.8
.set rag_top_k 4
.set rag_minimum_score 0
.set function_calling true
.set compress_threshold 1000
.set dry_run true

@ -41,8 +41,10 @@ rag_chunk_size: null
rag_chunk_overlap: null
# Specifies the number of documents to retrieve
rag_top_k: 4
# Specifies the minimum relevance score for retrieved documents
rag_minimum_score: 0
# Specifies the minimum relevance score for vector search
rag_min_score_vector: 0
# Specifies the minimum relevance score for full-text search
rag_min_score_text: 0
# Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
rag_template: |

@ -169,12 +169,22 @@ impl Input {
if !self.text.is_empty() {
let rag = self.config.read().rag.clone();
if let Some(rag) = rag {
let (top_k, minimum_score) = {
let (top_k, min_score_vector, min_score_text) = {
let config = self.config.read();
(config.rag_top_k, config.rag_minimum_score)
(
config.rag_top_k,
config.rag_min_score_vector,
config.rag_min_score_text,
)
};
let embeddings = rag
.search(&self.text, top_k, minimum_score, abort_signal)
.search(
&self.text,
top_k,
min_score_vector,
min_score_text,
abort_signal,
)
.await?;
let text = self.config.read().rag_template(&embeddings, &self.text);
self.patched_text = Some(text);

@ -105,7 +105,8 @@ pub struct Config {
pub rag_chunk_size: Option<usize>,
pub rag_chunk_overlap: Option<usize>,
pub rag_top_k: usize,
pub rag_minimum_score: f32,
pub rag_min_score_vector: f32,
pub rag_min_score_text: f32,
pub rag_template: Option<String>,
pub compress_threshold: usize,
pub summarize_prompt: Option<String>,
@ -158,7 +159,8 @@ impl Default for Config {
rag_chunk_size: None,
rag_chunk_overlap: None,
rag_top_k: 4,
rag_minimum_score: 0.0,
rag_min_score_vector: 0.0,
rag_min_score_text: 0.0,
rag_template: None,
compress_threshold: 4000,
summarize_prompt: None,
@ -441,7 +443,6 @@ impl Config {
("temperature", format_option_value(&role.temperature())),
("top_p", format_option_value(&role.top_p())),
("rag_top_k", self.rag_top_k.to_string()),
("rag_minimum_score", self.rag_minimum_score.to_string()),
("function_calling", self.function_calling.to_string()),
("compress_threshold", self.compress_threshold.to_string()),
("dry_run", self.dry_run.to_string()),
@ -494,11 +495,6 @@ impl Config {
self.rag_top_k = value;
}
}
"rag_minimum_score" => {
if let Some(value) = parse_value(value)? {
self.rag_minimum_score = value;
}
}
"function_calling" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.function_calling = value;
@ -1057,7 +1053,6 @@ impl Config {
"temperature",
"top_p",
"rag_top_k",
"rag_minimum_score",
"function_calling",
"compress_threshold",
"save",

@ -0,0 +1,172 @@
use rayon::prelude::*;
use std::collections::HashMap;
use std::f64;
#[derive(Debug, Clone)]
pub struct BM25Options {
    // Term-frequency saturation parameter: higher k1 means term frequency
    // keeps increasing the score for longer before saturating.
    k1: f64,
    // Document-length normalization strength in [0, 1]; 0 disables length
    // normalization, 1 fully normalizes by doc length / average length.
    b: f64,
    // Floor factor for negative IDF values: terms appearing in more than
    // half the documents get idf = epsilon * average_idf (see calc_idf).
    epsilon: f64,
}
impl Default for BM25Options {
    // Classic Okapi BM25 defaults (k1 = 1.5, b = 0.75, epsilon = 0.25).
    fn default() -> Self {
        Self {
            k1: 1.5,
            b: 0.75,
            epsilon: 0.25,
        }
    }
}
// In-memory BM25 (Okapi) index over a corpus of documents identified by `T`.
// All statistics are precomputed at construction; searching is lookup-only.
#[derive(Debug, Clone)]
pub struct BM25<T> {
    options: BM25Options,
    // Number of documents in the corpus.
    corpus_size: usize,
    // Average document length, in tokens.
    avgdl: f64,
    // Per-document map of term -> occurrence count; parallel to `doc_ids`.
    doc_freqs: Vec<HashMap<String, u32>>,
    // Caller-supplied document ids; parallel to `doc_freqs` and `doc_len`.
    doc_ids: Vec<T>,
    // Inverse document frequency per term (negatives floored, see calc_idf).
    idf: HashMap<String, f64>,
    // Token count of each document; parallel to `doc_ids`.
    doc_len: Vec<usize>,
}
impl<T: Clone> BM25<T> {
    /// Builds a BM25 index over `corpus`, a list of `(doc_id, text)` pairs.
    ///
    /// Tokenization runs in parallel via rayon; term/document statistics
    /// (frequencies, lengths, IDF) are computed eagerly so searches only
    /// need lookups.
    pub fn new(corpus: Vec<(T, String)>, options: BM25Options) -> Self {
        let (doc_ids, docs): (Vec<T>, Vec<String>) = corpus.into_iter().unzip();
        let tokenized_docs: Vec<Vec<String>> =
            docs.into_par_iter().map(|text| tokenize(&text)).collect();
        let mut bm25 = BM25 {
            options,
            corpus_size: 0,
            avgdl: 0.0,
            doc_freqs: Vec::new(),
            doc_ids,
            idf: HashMap::new(),
            doc_len: Vec::new(),
        };
        let doc_count_per_term = bm25.initialize(tokenized_docs);
        bm25.calc_idf(doc_count_per_term);
        bm25
    }

    /// Returns up to `top_k` document ids ranked by BM25 relevance to
    /// `query`, best first. When `min_score` is given, documents scoring
    /// below it are dropped before ranking.
    pub fn search(&self, query: &str, top_k: usize, min_score: Option<f64>) -> Vec<T> {
        let scores = self.get_scores(query);
        let mut indexed_scores: Vec<(T, f64)> = scores
            .into_iter()
            .enumerate()
            .filter(|(_, score)| min_score.map_or(true, |min| *score >= min))
            .map(|(i, score)| (self.doc_ids[i].clone(), score))
            .collect();
        // total_cmp is a total order on f64, so this cannot panic even on a
        // NaN score (partial_cmp().unwrap() would). The sort is stable, so
        // equal scores keep ascending index order.
        indexed_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        indexed_scores
            .into_iter()
            .take(top_k)
            .map(|(id, _)| id)
            .collect()
    }

    /// Computes the raw BM25 score of `query` against every document;
    /// the returned vector is indexed like `doc_ids`.
    pub fn get_scores(&self, query: &str) -> Vec<f64> {
        let mut scores = vec![0.0; self.corpus_size];
        for term in tokenize(query) {
            // Terms absent from the corpus contribute nothing.
            if let Some(idf) = self.idf.get(&term) {
                for (i, doc) in self.doc_freqs.iter().enumerate() {
                    let freq = f64::from(doc.get(&term).copied().unwrap_or(0));
                    scores[i] += *idf
                        * (freq * (self.options.k1 + 1.0)
                            / (freq
                                + self.options.k1
                                    * (1.0 - self.options.b
                                        + self.options.b * self.doc_len[i] as f64
                                            / self.avgdl)));
                }
            }
        }
        scores
    }

    /// Records per-document statistics (term frequencies, token counts) and
    /// returns, for each term, the number of documents containing it.
    fn initialize(&mut self, corpus: Vec<Vec<String>>) -> HashMap<String, usize> {
        let mut doc_count_per_term = HashMap::new();
        let mut total_tokens = 0;
        for document in corpus {
            self.doc_len.push(document.len());
            total_tokens += document.len();
            let mut frequencies: HashMap<String, u32> = HashMap::new();
            for word in document {
                *frequencies.entry(word).or_insert(0) += 1;
            }
            // Each distinct term of this document appears in one more doc.
            for word in frequencies.keys() {
                *doc_count_per_term.entry(word.clone()).or_insert(0) += 1;
            }
            self.doc_freqs.push(frequencies);
            self.corpus_size += 1;
        }
        // NOTE(review): an empty corpus yields avgdl = 0/0 = NaN; in practice
        // the index is built from at least one document — confirm callers.
        self.avgdl = total_tokens as f64 / self.corpus_size as f64;
        doc_count_per_term
    }

    /// Computes each term's IDF, then floors negative IDFs (terms appearing
    /// in more than half the documents) at `epsilon * average_idf`, matching
    /// the classic Okapi BM25 formulation.
    fn calc_idf(&mut self, map: HashMap<String, usize>) {
        let mut idf_sum = 0.0;
        let mut negative_idfs = Vec::new();
        for (word, freq) in map {
            let idf = (self.corpus_size as f64 - freq as f64 + 0.5).ln() - (freq as f64 + 0.5).ln();
            self.idf.insert(word.clone(), idf);
            idf_sum += idf;
            if idf < 0.0 {
                negative_idfs.push(word);
            }
        }
        let average_idf = idf_sum / self.idf.len() as f64;
        for word in negative_idfs {
            self.idf.insert(word, self.options.epsilon * average_idf);
        }
    }
}
/// Splits `text` into tokens on single spaces.
///
/// Consecutive spaces produce empty tokens and punctuation stays attached
/// to words; scoring relies on exact token matches.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split(' ') {
        tokens.push(word.to_string());
    }
    tokens
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_bm25() {
        // Three tiny documents; only document 1 contains the query terms.
        let corpus = vec![
            (0, "Hello there good man!".into()),
            (1, "It is quite windy in London".into()),
            (2, "How is the weather today?".into()),
        ];
        let bm25 = BM25::new(corpus, BM25Options::default());
        let scores = bm25.get_scores("windy London");
        // The exact value pins both the IDF computation and tf normalization.
        assert_eq!(scores, [0.0, 0.9372947225064051, 0.0]);
        let top_n = bm25.search("windy London", 3, None);
        // Best score first; docs 0 and 2 tie at 0.0 and the stable sort
        // keeps their original (ascending index) order.
        assert_eq!(top_n, vec![1, 0, 2])
    }
}

@ -1,3 +1,4 @@
use self::bm25::*;
use self::loader::*;
use self::splitter::*;
@ -5,6 +6,7 @@ use crate::client::*;
use crate::config::*;
use crate::utils::*;
mod bm25;
mod loader;
mod splitter;
@ -16,8 +18,7 @@ use inquire::{required, validator::Validation, Select, Text};
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Debug;
use std::{io::BufReader, path::Path};
use std::{collections::HashMap, fmt::Debug, io::BufReader, path::Path};
use tokio::sync::mpsc;
pub struct Rag {
@ -26,6 +27,7 @@ pub struct Rag {
path: String,
model: Model,
hnsw: Hnsw<'static, f32, DistCosine>,
bm25: BM25<VectorID>,
data: RagData,
}
@ -85,6 +87,7 @@ impl Rag {
pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result<Self> {
let hnsw = data.build_hnsw();
let bm25 = data.build_bm25();
let model = Model::retrieve_embedding(&config.read(), &data.model)?;
let client = init_client(config, Some(model.clone()))?;
let rag = Rag {
@ -94,6 +97,7 @@ impl Rag {
data,
model,
hnsw,
bm25,
};
Ok(rag)
}
@ -194,12 +198,13 @@ impl Rag {
&self,
text: &str,
top_k: usize,
minimum_score: f32,
min_score_vector: f32,
min_score_text: f32,
abort_signal: AbortSignal,
) -> Result<String> {
let (stop_spinner_tx, _) = run_spinner("Searching").await;
let ret = tokio::select! {
ret = self.search_impl(text, top_k, minimum_score) => {
ret = self.hybird_search(text, top_k, min_score_vector, min_score_text) => {
ret
}
_ = watch_abort_signal(abort_signal) => {
@ -289,18 +294,44 @@ impl Rag {
Ok(())
}
async fn search_impl(
async fn hybird_search(
&self,
text: &str,
query: &str,
top_k: usize,
minimum_score: f32,
min_score_vector: f32,
min_score_text: f32,
) -> Result<Vec<String>> {
let (vector_search_result, text_search_result) = tokio::join!(
self.vector_search(query, top_k, min_score_vector),
self.text_search(query, top_k, min_score_text)
);
let vector_search_ids = vector_search_result?;
let text_search_ids = text_search_result?;
let ids = reciprocal_rank_fusion(vector_search_ids, text_search_ids, 1.0, 1.0, top_k);
let output: Vec<_> = ids
.into_iter()
.filter_map(|id| {
let (file_index, document_index) = split_vector_id(id);
let file = self.data.files.get(file_index)?;
let document = file.documents.get(document_index)?;
Some(document.page_content.clone())
})
.collect();
Ok(output)
}
async fn vector_search(
&self,
query: &str,
top_k: usize,
min_score: f32,
) -> Result<Vec<VectorID>> {
let splitter = RecursiveCharacterTextSplitter::new(
self.data.chunk_size,
self.data.chunk_overlap,
&DEFAULT_SEPARATES,
);
let texts = splitter.split_text(text);
let texts = splitter.split_text(query);
let embeddings_data = EmbeddingsData::new(texts, true);
let embeddings = self.create_embeddings(embeddings_data, None).await?;
let output = self
@ -310,13 +341,10 @@ impl Rag {
.flat_map(|list| {
list.into_iter()
.filter_map(|v| {
if v.distance < minimum_score {
if v.distance < min_score {
return None;
}
let (file_index, document_index) = split_vector_id(v.d_id);
let file = self.data.files.get(file_index)?;
let document = file.documents.get(document_index)?;
Some(document.page_content.clone())
Some(v.d_id)
})
.collect::<Vec<_>>()
})
@ -324,6 +352,16 @@ impl Rag {
Ok(output)
}
/// Full-text (BM25) search over all indexed documents.
///
/// Returns up to `top_k` document ids whose BM25 score is at least
/// `min_score` (compared as f64), best first.
async fn text_search(
&self,
query: &str,
top_k: usize,
min_score: f32,
) -> Result<Vec<VectorID>> {
// NOTE(review): bm25.search is synchronous; the fn is async only so it can
// be awaited uniformly alongside vector_search in tokio::join!.
let output = self.bm25.search(query, top_k, Some(min_score as f64));
Ok(output)
}
async fn create_embeddings(
&self,
data: EmbeddingsData,
@ -393,6 +431,17 @@ impl RagData {
hnsw.parallel_insert(&list);
hnsw
}
/// Builds the BM25 full-text index from every document of every file,
/// keyed by the same packed id scheme used by the vector index so the two
/// result sets can be fused over a shared id space.
pub fn build_bm25(&self) -> BM25<VectorID> {
let mut corpus = vec![];
for (file_index, file) in self.files.iter().enumerate() {
for (document_index, document) in file.documents.iter().enumerate() {
// Pack (file_index, document_index) into a single VectorID.
let id = combine_vector_id(file_index, document_index);
corpus.push((id, document.page_content.clone()));
}
}
BM25::new(corpus, BM25Options::default())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -502,3 +551,29 @@ fn progress(spinner_message_tx: &Option<mpsc::UnboundedSender<String>>, message:
let _ = tx.send(message);
}
}
/// Merges two ranked id lists with Reciprocal Rank Fusion (RRF).
///
/// Each id contributes `(1 / (rrf_k + rank)) * weight` per list it appears
/// in (rank is 1-based; `rrf_k = 2 * top_k` dampens the head of each list).
/// Ids are sorted by their summed score and the best `top_k` are returned.
fn reciprocal_rank_fusion(
    vector_search_ids: Vec<VectorID>,
    text_search_ids: Vec<VectorID>,
    vector_search_weight: f32,
    text_search_weight: f32,
    top_k: usize,
) -> Vec<VectorID> {
    let rrf_k = top_k * 2;
    let mut scores: HashMap<VectorID, f32> = HashMap::new();
    // Accumulate both lists with one loop body to keep the formula in one place.
    for (ids, weight) in [
        (&vector_search_ids, vector_search_weight),
        (&text_search_ids, text_search_weight),
    ] {
        for (rank, &id) in ids.iter().enumerate() {
            *scores.entry(id).or_default() += (1.0 / ((rrf_k + rank + 1) as f32)) * weight;
        }
    }
    let mut fused: Vec<(VectorID, f32)> = scores.into_iter().collect();
    // total_cmp is a total order on f32, so this cannot panic even if a
    // caller passes a NaN weight (partial_cmp().unwrap() would).
    fused.sort_by(|a, b| b.1.total_cmp(&a.1));
    fused.into_iter().take(top_k).map(|(id, _)| id).collect()
}

@ -1,5 +1,6 @@
use anyhow::Result;
use crossterm::{cursor, queue, style, terminal};
use is_terminal::IsTerminal;
use std::{
io::{stdout, Stdout, Write},
time::Duration,
@ -76,6 +77,7 @@ async fn run_spinner_inner(
mut message_rx: mpsc::UnboundedReceiver<String>,
) -> Result<()> {
let mut writer = stdout();
let is_stdout_terminal = stdout().is_terminal();
let mut spinner = Spinner::new(&message);
let mut interval = interval(Duration::from_millis(50));
tokio::select! {
@ -83,7 +85,9 @@ async fn run_spinner_inner(
loop {
tokio::select! {
_ = interval.tick() => {
let _ = spinner.step(&mut writer);
if is_stdout_terminal {
let _ = spinner.step(&mut writer);
}
}
message = message_rx.recv() => {
if let Some(message) = message {
@ -94,7 +98,9 @@ async fn run_spinner_inner(
}
} => {}
_ = stop_rx => {
spinner.stop(&mut writer)?;
if is_stdout_terminal {
spinner.stop(&mut writer)?;
}
}
}
Ok(())

Loading…
Cancel
Save