feat: add batch_size to RAG yaml (#876)

sigoden 4 weeks ago committed by GitHub
parent 029058c43d
commit 00c4a6e421

@@ -78,6 +78,7 @@ impl Rag {
             chunk_overlap,
             reranker_model,
             top_k,
+            embedding_model.max_batch_size(),
         );
         let mut rag = Self::create(config, name, save_path, data)?;
         let mut paths = doc_paths.to_vec();
@@ -288,6 +289,7 @@ impl Rag {
             "chunk_overlap": self.data.chunk_overlap,
             "reranker_model": self.data.reranker_model,
             "top_k": self.data.top_k,
+            "batch_size": self.data.batch_size,
             "document_paths": self.data.document_paths,
             "files": files,
         });
@@ -559,18 +561,22 @@ impl Rag {
     ) -> Result<EmbeddingsOutput> {
         let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?;
         let EmbeddingsData { texts, query } = data;
-        let size = match self.embedding_model.max_input_tokens() {
+        let batch_size = self
+            .data
+            .batch_size
+            .or_else(|| self.embedding_model.max_batch_size());
+        let batch_size = match self.embedding_model.max_input_tokens() {
             Some(max_input_tokens) => {
                 let x = max_input_tokens / self.data.chunk_size;
-                match self.embedding_model.max_batch_size() {
+                match batch_size {
                     Some(y) => x.min(y),
                     None => x,
                 }
             }
-            None => self.embedding_model.max_batch_size().unwrap_or(1),
+            None => batch_size.unwrap_or(1),
         };
         let mut output = vec![];
-        let batch_chunks = texts.chunks(size.max(1));
+        let batch_chunks = texts.chunks(batch_size.max(1));
         let batch_chunks_len = batch_chunks.len();
         for (index, texts) in batch_chunks.enumerate() {
             progress(
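
For readers tracing the new resolution order in the hunk above, here is a minimal sketch of the same batch-size logic pulled out as a free function. The function name, its parameters, and the standalone form are illustrative only and not part of the patch; the real code reads these values from RagData and the embedding model.

// Sketch of the batch-size resolution introduced above; names are illustrative.
fn resolve_batch_size(
    configured_batch_size: Option<usize>, // RagData::batch_size from the rag yaml
    model_max_batch_size: Option<usize>,  // embedding_model.max_batch_size()
    max_input_tokens: Option<usize>,      // embedding_model.max_input_tokens()
    chunk_size: usize,                    // RagData::chunk_size
) -> usize {
    // The yaml value takes precedence; the model's own limit is only a fallback.
    let batch_size = configured_batch_size.or(model_max_batch_size);
    let resolved = match max_input_tokens {
        // Cap the batch so batch_size * chunk_size stays within the token budget.
        Some(max_input_tokens) => {
            let by_tokens = max_input_tokens / chunk_size;
            match batch_size {
                Some(limit) => by_tokens.min(limit),
                None => by_tokens,
            }
        }
        None => batch_size.unwrap_or(1),
    };
    // texts.chunks(n) panics on n == 0, hence the guard the patch applies at the call site.
    resolved.max(1)
}

fn main() {
    // An 8192-token input limit with 1500-token chunks allows at most 5 texts per batch,
    // even if the yaml asks for 25.
    assert_eq!(resolve_batch_size(Some(25), Some(100), Some(8192), 1500), 5);
    // Without token info, the yaml value (or the model limit, or 1) is used directly.
    assert_eq!(resolve_batch_size(None, Some(100), None, 1500), 100);
}

In other words, the configured value only adjusts the per-request count within the existing limits: the max_input_tokens / chunk_size cap still applies whenever the model reports a token budget, and the final .max(1) mirrors the guard at the texts.chunks() call.
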
@@ -598,6 +604,7 @@ pub struct RagData {
     pub chunk_overlap: usize,
     pub reranker_model: Option<String>,
     pub top_k: usize,
+    pub batch_size: Option<usize>,
     pub next_file_id: FileId,
     pub document_paths: Vec<String>,
     pub files: IndexMap<FileId, RagFile>,
@@ -611,6 +618,9 @@ impl Debug for RagData {
             .field("embedding_model", &self.embedding_model)
             .field("chunk_size", &self.chunk_size)
             .field("chunk_overlap", &self.chunk_overlap)
+            .field("reranker_model", &self.reranker_model)
+            .field("top_k", &self.top_k)
+            .field("batch_size", &self.batch_size)
             .field("next_file_id", &self.next_file_id)
             .field("document_paths", &self.document_paths)
             .field("files", &self.files)
@@ -625,6 +635,7 @@ impl RagData {
         chunk_overlap: usize,
         reranker_model: Option<String>,
         top_k: usize,
+        batch_size: Option<usize>,
     ) -> Self {
         Self {
             embedding_model,
@@ -632,6 +643,7 @@ impl RagData {
             chunk_overlap,
             reranker_model,
             top_k,
+            batch_size,
             next_file_id: 0,
             document_paths: Default::default(),
             files: Default::default(),
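
Since the commit title says the new field lands in the RAG yaml, here is a rough sketch of how batch_size would appear in that file. It assumes RagData derives Serialize/Deserialize and that the rag file is plain YAML; the struct shown is a trimmed stand-in, the values are made up, and serde_yaml stands in for whatever serializer the project actually uses.

// Trimmed stand-in for RagData; field names and types match the diff above,
// everything else (values, serializer) is illustrative.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct RagDataSketch {
    embedding_model: String,
    chunk_size: usize,
    chunk_overlap: usize,
    reranker_model: Option<String>,
    top_k: usize,
    batch_size: Option<usize>, // new: None falls back to the model's max_batch_size()
}

fn main() -> Result<(), serde_yaml::Error> {
    let data = RagDataSketch {
        embedding_model: "openai:text-embedding-3-small".into(), // illustrative value
        chunk_size: 1500,
        chunk_overlap: 75,
        reranker_model: None,
        top_k: 5,
        batch_size: Some(25),
    };
    // Prints YAML including the new `batch_size: 25` line.
    println!("{}", serde_yaml::to_string(&data)?);
    Ok(())
}

Editing batch_size in the saved yaml overrides the model default at embedding time, while leaving it as None keeps the previous behavior of using the model's max_batch_size().
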
