|
|
|
@ -20,7 +20,6 @@ use std::fmt::Debug;
|
|
|
|
|
use std::{io::BufReader, path::Path};
|
|
|
|
|
use tokio::sync::mpsc;
|
|
|
|
|
|
|
|
|
|
/// Fixed overlap value handed to `Splitter::new` as its second argument by
/// the call sites that do not read `chunk_overlap` from `RagData`.
// NOTE(review): the unit (characters vs. tokens) is not visible here — confirm against `Splitter`.
pub const CHUNK_OVERLAP: usize = 20;
|
|
|
|
|
// NOTE(review): no use of this constant is visible in this excerpt; presumably
// the cutoff applied to similarity scores when searching — confirm at the call site.
pub const SIMILARITY_THRESHOLD: f32 = 0.25;
|
|
|
|
|
|
|
|
|
|
pub struct Rag {
|
|
|
|
@ -53,9 +52,9 @@ impl Rag {
|
|
|
|
|
) -> Result<Self> {
|
|
|
|
|
debug!("init rag: {name}");
|
|
|
|
|
let model = select_embedding_model(config)?;
|
|
|
|
|
let chunk_size = model.default_chunk_size();
|
|
|
|
|
let chunk_size = set_chunk_size(chunk_size)?;
|
|
|
|
|
let data = RagData::new(&model.id(), chunk_size);
|
|
|
|
|
let chunk_size = set_chunk_size(&model)?;
|
|
|
|
|
let chunk_overlap = chunk_size / 20;
|
|
|
|
|
let data = RagData::new(&model.id(), chunk_size, chunk_overlap);
|
|
|
|
|
let mut rag = Self::create(config, name, save_path, data)?;
|
|
|
|
|
let mut paths = doc_paths.to_vec();
|
|
|
|
|
if paths.is_empty() {
|
|
|
|
@ -188,7 +187,7 @@ impl Rag {
|
|
|
|
|
.map(|v| v.to_string_lossy().to_lowercase())
|
|
|
|
|
.unwrap_or_default();
|
|
|
|
|
let separator = autodetect_separator(&extension);
|
|
|
|
|
let splitter = Splitter::new(self.data.chunk_size, CHUNK_OVERLAP, separator);
|
|
|
|
|
let splitter = Splitter::new(self.data.chunk_size, self.data.chunk_overlap, separator);
|
|
|
|
|
let documents = load(&path, &extension)
|
|
|
|
|
.await
|
|
|
|
|
.with_context(|| format!("Failed to load text at '{path}'"))?;
|
|
|
|
@ -228,7 +227,11 @@ impl Rag {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn search_impl(&self, text: &str, top_k: usize) -> Result<Vec<String>> {
|
|
|
|
|
let splitter = Splitter::new(self.data.chunk_size, CHUNK_OVERLAP, &DEFAULT_SEPARATES);
|
|
|
|
|
let splitter = Splitter::new(
|
|
|
|
|
self.data.chunk_size,
|
|
|
|
|
self.data.chunk_overlap,
|
|
|
|
|
&DEFAULT_SEPARATES,
|
|
|
|
|
);
|
|
|
|
|
let texts = splitter.split_text(text);
|
|
|
|
|
let embeddings_data = EmbeddingsData::new(texts, true);
|
|
|
|
|
let embeddings = self.create_embeddings(embeddings_data, None).await?;
|
|
|
|
@ -291,15 +294,17 @@ impl Rag {
|
|
|
|
|
pub struct RagData {
|
|
|
|
|
pub model: String,
|
|
|
|
|
pub chunk_size: usize,
|
|
|
|
|
pub chunk_overlap: usize,
|
|
|
|
|
pub files: Vec<RagFile>,
|
|
|
|
|
pub vectors: IndexMap<VectorID, Vec<f32>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl RagData {
|
|
|
|
|
pub fn new(model: &str, chunk_size: usize) -> Self {
|
|
|
|
|
pub fn new(model: &str, chunk_size: usize, chunk_overlap: usize) -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
model: model.to_string(),
|
|
|
|
|
chunk_size,
|
|
|
|
|
chunk_overlap,
|
|
|
|
|
files: Default::default(),
|
|
|
|
|
vectors: Default::default(),
|
|
|
|
|
}
|
|
|
|
@ -397,17 +402,25 @@ fn select_embedding_model(config: &GlobalConfig) -> Result<Model> {
|
|
|
|
|
Ok(model)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn set_chunk_size(chunk_size: usize) -> Result<usize> {
|
|
|
|
|
let value = Text::new("Set chunk size:")
|
|
|
|
|
.with_default(&chunk_size.to_string())
|
|
|
|
|
fn set_chunk_size(model: &Model) -> Result<usize> {
|
|
|
|
|
let default_value = model.default_chunk_size().to_string();
|
|
|
|
|
let help_message = model
|
|
|
|
|
.max_input_tokens()
|
|
|
|
|
.map(|v| format!("The model's max_input_token is {v}"));
|
|
|
|
|
|
|
|
|
|
let mut text = Text::new("Set chunk size:")
|
|
|
|
|
.with_default(&default_value)
|
|
|
|
|
.with_validator(move |text: &str| {
|
|
|
|
|
let out = match text.parse::<usize>() {
|
|
|
|
|
Ok(_) => Validation::Valid,
|
|
|
|
|
Err(_) => Validation::Invalid("Must be a integer".into()),
|
|
|
|
|
};
|
|
|
|
|
Ok(out)
|
|
|
|
|
})
|
|
|
|
|
.prompt()?;
|
|
|
|
|
});
|
|
|
|
|
if let Some(help_message) = &help_message {
|
|
|
|
|
text = text.with_help_message(help_message);
|
|
|
|
|
}
|
|
|
|
|
let value = text.prompt()?;
|
|
|
|
|
value.parse().map_err(|_| anyhow!("Invalid chunk_size"))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|