diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 4f65925..e84ece3 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -19,12 +19,9 @@ use parking_lot::RwLock; use path_absolutize::Absolutize; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::{collections::HashMap, fmt::Debug, fs, path::Path, time::Duration}; +use std::{collections::HashMap, env, fmt::Debug, fs, path::Path, time::Duration}; use tokio::time::sleep; -const EMBEDDING_RETRY_LIMIT: usize = 3; -const RERANK_RETRY_LIMIT: usize = 2; - pub struct Rag { config: GlobalConfig, name: String, @@ -487,23 +484,7 @@ impl Rag { } } let data = RerankData::new(query.to_string(), documents, top_k); - let mut retry = 0; - let list = loop { - retry += 1; - match client.rerank(&data).await { - Ok(result) => break result, - Err(e) if retry < RERANK_RETRY_LIMIT => { - debug!("retry {} failed: {}", retry, e); - sleep(Duration::from_secs(retry as _)).await; - continue; - } - Err(e) => { - return Err(e).with_context(|| { - format!("Failed to rerank after {RERANK_RETRY_LIMIT} attempts") - })? 
- } - } - }; + let list = client.rerank(&data).await.context("Failed to rerank")?; let ids: Vec<_> = list .into_iter() .take(top_k) @@ -598,6 +579,10 @@ impl Rag { let mut output = vec![]; let batch_chunks = texts.chunks(batch_size.max(1)); let batch_chunks_len = batch_chunks.len(); + let retry_limit = env::var(get_env_name("embeddings_retry_limit")) + .ok() + .and_then(|v| v.parse::<u32>().ok()) + .unwrap_or(2); for (index, texts) in batch_chunks.enumerate() { progress( &spinner, @@ -612,16 +597,14 @@ impl Rag { retry += 1; match embedding_client.embeddings(&chunk_data).await { Ok(v) => break v, - Err(e) if retry < EMBEDDING_RETRY_LIMIT => { + Err(e) if retry < retry_limit => { debug!("retry {} failed: {}", retry, e); - sleep(Duration::from_secs(retry as _)).await; + sleep(Duration::from_secs(2u64.pow(retry - 1))).await; continue; } Err(e) => { return Err(e).with_context(|| { - format!( - "Failed to create embedding after {EMBEDDING_RETRY_LIMIT} attempts" - ) + format!("Failed to create embedding after {retry_limit} attempts") })? } }