From 239dd7c0c03d0430c55c2c41cf56cf0dd537199b Mon Sep 17 00:00:00 2001 From: Nilanjan De Date: Thu, 28 Mar 2024 10:52:57 +0400 Subject: [PATCH] langchain[patch]: Use map() and avoid "ValueError: max() arg is an empty sequence" in MergerRetriever (#18679) - **Issue:** When passing an empty list to MergerRetriever it fails with error: ValueError: max() arg is an empty sequence - **Description:** We have a use case where we dynamically select retrievers and use MergerRetriever for merging the output of the retrievers. We faced this issue when the retriever_docs list is empty. Adding a default 0 for cases when retriever_docs is an empty list to avoid "ValueError: max() arg is an empty sequence". Also, changed to use map() which is more than twice as fast compared to the current implementation. ``` import timeit # Sample retriever_docs with varying lengths of sublists retriever_docs = [[i for i in range(j)] for j in range(1, 1000)] # First code snippet code1 = ''' max_docs = max(len(docs) for docs in retriever_docs) ''' # Second code snippet code2 = ''' max_docs = max(map(len, retriever_docs), default=0) ''' # Benchmarking time1 = timeit.timeit(stmt=code1, globals=globals(), number=10000) time2 = timeit.timeit(stmt=code2, globals=globals(), number=10000) # Output print(f"Execution time for code snippet 1: {time1} seconds") print(f"Execution time for code snippet 2: {time2} seconds") ``` - **Dependencies:** none --- libs/langchain/langchain/retrievers/merger_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/retrievers/merger_retriever.py b/libs/langchain/langchain/retrievers/merger_retriever.py index f5326773bc..4979779c25 100644 --- a/libs/langchain/langchain/retrievers/merger_retriever.py +++ b/libs/langchain/langchain/retrievers/merger_retriever.py @@ -80,7 +80,7 @@ class MergerRetriever(BaseRetriever): # Merge the results of the retrievers. merged_documents = [] - max_docs = max(len(docs) for docs in retriever_docs) + max_docs = max(map(len, retriever_docs), default=0) for i in range(max_docs): for retriever, doc in zip(self.retrievers, retriever_docs): if i < len(doc): @@ -113,7 +113,7 @@ class MergerRetriever(BaseRetriever): # Merge the results of the retrievers. merged_documents = [] - max_docs = max(len(docs) for docs in retriever_docs) + max_docs = max(map(len, retriever_docs), default=0) for i in range(max_docs): for retriever, doc in zip(self.retrievers, retriever_docs): if i < len(doc):