From 00c4a6e421f01590ffc3fa5601e93d5ec755fca7 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sun, 22 Sep 2024 09:40:52 +0800 Subject: [PATCH] feat: add batch_size to RAG yaml (#876) --- src/rag/mod.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 973bc1c..62a8ff8 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -78,6 +78,7 @@ impl Rag { chunk_overlap, reranker_model, top_k, + embedding_model.max_batch_size(), ); let mut rag = Self::create(config, name, save_path, data)?; let mut paths = doc_paths.to_vec(); @@ -288,6 +289,7 @@ impl Rag { "chunk_overlap": self.data.chunk_overlap, "reranker_model": self.data.reranker_model, "top_k": self.data.top_k, + "batch_size": self.data.batch_size, "document_paths": self.data.document_paths, "files": files, }); @@ -559,18 +561,22 @@ impl Rag { ) -> Result { let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?; let EmbeddingsData { texts, query } = data; - let size = match self.embedding_model.max_input_tokens() { + let batch_size = self + .data + .batch_size + .or_else(|| self.embedding_model.max_batch_size()); + let batch_size = match self.embedding_model.max_input_tokens() { Some(max_input_tokens) => { let x = max_input_tokens / self.data.chunk_size; - match self.embedding_model.max_batch_size() { + match batch_size { Some(y) => x.min(y), None => x, } } - None => self.embedding_model.max_batch_size().unwrap_or(1), + None => batch_size.unwrap_or(1), }; let mut output = vec![]; - let batch_chunks = texts.chunks(size.max(1)); + let batch_chunks = texts.chunks(batch_size.max(1)); let batch_chunks_len = batch_chunks.len(); for (index, texts) in batch_chunks.enumerate() { progress( @@ -598,6 +604,7 @@ pub struct RagData { pub chunk_overlap: usize, pub reranker_model: Option, pub top_k: usize, + pub batch_size: Option, pub next_file_id: FileId, pub document_paths: Vec, pub files: IndexMap, @@ -611,6 +618,9 @@ impl Debug for RagData { .field("embedding_model", &self.embedding_model) .field("chunk_size", &self.chunk_size) .field("chunk_overlap", &self.chunk_overlap) + .field("reranker_model", &self.reranker_model) + .field("top_k", &self.top_k) + .field("batch_size", &self.batch_size) .field("next_file_id", &self.next_file_id) .field("document_paths", &self.document_paths) .field("files", &self.files) @@ -625,6 +635,7 @@ impl RagData { chunk_overlap: usize, reranker_model: Option, top_k: usize, + batch_size: Option, ) -> Self { Self { embedding_model, @@ -632,6 +643,7 @@ impl RagData { chunk_overlap, reranker_model, top_k, + batch_size, next_file_id: 0, document_paths: Default::default(), files: Default::default(),