refactor: rag default_chunk_size (#588)

pull/589/head
sigoden 4 months ago committed by GitHub
parent e9e48f2320
commit 1f33b3a07a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -68,12 +68,12 @@
- name: text-embedding-3-large
mode: embedding
max_input_tokens: 8191
default_chunk_size: 8000
default_chunk_size: 4000
max_concurrent_chunks: 100
- name: text-embedding-3-small
mode: embedding
max_input_tokens: 8191
default_chunk_size: 8000
default_chunk_size: 4000
max_concurrent_chunks: 100
- platform: gemini
@ -179,7 +179,7 @@
- name: mistral-embed
mode: embedding
max_input_tokens: 8092
default_chunk_size: 8000
default_chunk_size: 4000
- platform: cohere
# docs:
@ -204,12 +204,12 @@
- name: embed-english-v3.0
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
default_chunk_size: 500
max_concurrent_chunks: 96
- name: embed-multilingual-v3.0
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
default_chunk_size: 500
max_concurrent_chunks: 96
- platform: perplexity

@ -20,7 +20,6 @@ use std::fmt::Debug;
use std::{io::BufReader, path::Path};
use tokio::sync::mpsc;
pub const CHUNK_OVERLAP: usize = 20;
pub const SIMILARITY_THRESHOLD: f32 = 0.25;
pub struct Rag {
@ -53,9 +52,9 @@ impl Rag {
) -> Result<Self> {
debug!("init rag: {name}");
let model = select_embedding_model(config)?;
let chunk_size = model.default_chunk_size();
let chunk_size = set_chunk_size(chunk_size)?;
let data = RagData::new(&model.id(), chunk_size);
let chunk_size = set_chunk_size(&model)?;
let chunk_overlap = chunk_size / 20;
let data = RagData::new(&model.id(), chunk_size, chunk_overlap);
let mut rag = Self::create(config, name, save_path, data)?;
let mut paths = doc_paths.to_vec();
if paths.is_empty() {
@ -188,7 +187,7 @@ impl Rag {
.map(|v| v.to_string_lossy().to_lowercase())
.unwrap_or_default();
let separator = autodetect_separator(&extension);
let splitter = Splitter::new(self.data.chunk_size, CHUNK_OVERLAP, separator);
let splitter = Splitter::new(self.data.chunk_size, self.data.chunk_overlap, separator);
let documents = load(&path, &extension)
.await
.with_context(|| format!("Failed to load text at '{path}'"))?;
@ -228,7 +227,11 @@ impl Rag {
}
async fn search_impl(&self, text: &str, top_k: usize) -> Result<Vec<String>> {
let splitter = Splitter::new(self.data.chunk_size, CHUNK_OVERLAP, &DEFAULT_SEPARATES);
let splitter = Splitter::new(
self.data.chunk_size,
self.data.chunk_overlap,
&DEFAULT_SEPARATES,
);
let texts = splitter.split_text(text);
let embeddings_data = EmbeddingsData::new(texts, true);
let embeddings = self.create_embeddings(embeddings_data, None).await?;
@ -291,15 +294,17 @@ impl Rag {
pub struct RagData {
pub model: String,
pub chunk_size: usize,
pub chunk_overlap: usize,
pub files: Vec<RagFile>,
pub vectors: IndexMap<VectorID, Vec<f32>>,
}
impl RagData {
pub fn new(model: &str, chunk_size: usize) -> Self {
pub fn new(model: &str, chunk_size: usize, chunk_overlap: usize) -> Self {
Self {
model: model.to_string(),
chunk_size,
chunk_overlap,
files: Default::default(),
vectors: Default::default(),
}
@ -397,17 +402,25 @@ fn select_embedding_model(config: &GlobalConfig) -> Result<Model> {
Ok(model)
}
fn set_chunk_size(chunk_size: usize) -> Result<usize> {
let value = Text::new("Set chunk size:")
.with_default(&chunk_size.to_string())
fn set_chunk_size(model: &Model) -> Result<usize> {
let default_value = model.default_chunk_size().to_string();
let help_message = model
.max_input_tokens()
.map(|v| format!("The model's max_input_token is {v}"));
let mut text = Text::new("Set chunk size:")
.with_default(&default_value)
.with_validator(move |text: &str| {
let out = match text.parse::<usize>() {
Ok(_) => Validation::Valid,
Err(_) => Validation::Invalid("Must be a integer".into()),
};
Ok(out)
})
.prompt()?;
});
if let Some(help_message) = &help_message {
text = text.with_help_message(help_message);
}
let value = text.prompt()?;
value.parse().map_err(|_| anyhow!("Invalid chunk_size"))
}

Loading…
Cancel
Save