From 00a7c31ffd3cb1d2f16cdc4e3e7f69071b306beb Mon Sep 17 00:00:00 2001 From: Lorenzo <44714920+lorenzofavaro@users.noreply.github.com> Date: Mon, 4 Sep 2023 00:27:46 +0200 Subject: [PATCH] Fix: Nested Dicts Handling of Document Metadata (#9880) ## Description When the `MultiQueryRetriever` is used to get the list of documents relevant according to a query, inside a vector store, and at least one of these contains metadata with nested dictionaries, a `TypeError: unhashable type: 'dict'` exception is thrown. This is caused by the `unique_union` function which, to guarantee the uniqueness of the returned documents, tries, unsuccessfully, to hash the nested dictionaries and use them as part of the key. ```python unique_documents_dict = { (doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in documents } ``` ## Issue #9872 (MultiQueryRetriever (get_relevant_documents) raises TypeError: unhashable type: 'dict' with dic metadata) ## Solution A possible solution is to dump the metadata dict to a string and use it as part of the hashed key. 
```python unique_documents_dict = { (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc for doc in documents } ``` --------- Co-authored-by: Bagatur --- .../langchain/retrievers/multi_query.py | 23 +++++------ .../unit_tests/retrievers/test_multi_query.py | 40 +++++++++++++++++++ 2 files changed, 49 insertions(+), 14 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/test_multi_query.py diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 8398ce40d9..b99bb84a7e 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Sequence from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chains.llm import LLMChain @@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate( ) +def _unique_documents(documents: Sequence[Document]) -> List[Document]: + return [doc for i, doc in enumerate(documents) if doc not in documents[:i]] + + class MultiQueryRetriever(BaseRetriever): """Given a query, use an LLM to write a set of queries. - Retrieve docs for each query. Rake the unique union of all retrieved docs. + Retrieve docs for each query. Return the unique union of all retrieved docs. """ retriever: BaseRetriever @@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever): *, run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: - """Get relevated documents given a user query. + """Get relevant documents given a user query. 
Args: question: user query @@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever): """ queries = self.generate_queries(query, run_manager) documents = self.retrieve_documents(queries, run_manager) - unique_documents = self.unique_union(documents) - return unique_documents + return self.unique_union(documents) def generate_queries( self, question: str, run_manager: CallbackManagerForRetrieverRun @@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever): Returns: List of unique retrieved Documents """ - # Create a dictionary with page_content as keys to remove duplicates - # TODO: Add Document ID property (e.g., UUID) - unique_documents_dict = { - (doc.page_content, tuple(sorted(doc.metadata.items()))): doc - for doc in documents - } - - unique_documents = list(unique_documents_dict.values()) - return unique_documents + return _unique_documents(documents) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py new file mode 100644 index 0000000000..978950ec58 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py @@ -0,0 +1,40 @@ +from typing import List + +import pytest as pytest + +from langchain.retrievers.multi_query import _unique_documents +from langchain.schema import Document + + +@pytest.mark.parametrize( + "documents,expected", + [ + ([], []), + ([Document(page_content="foo")], [Document(page_content="foo")]), + ([Document(page_content="foo")] * 2, [Document(page_content="foo")]), + ( + [Document(page_content="foo", metadata={"bar": "baz"})] * 2, + [Document(page_content="foo", metadata={"bar": "baz"})], + ), + ( + [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2, + [Document(page_content="foo", metadata={"bar": [1, 2]})], + ), + ( + [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2, + [Document(page_content="foo", metadata={"bar": {1, 2}})], + ), + ( + [ + Document(page_content="foo", metadata={"bar": [1, 2]}), + 
Document(page_content="foo", metadata={"bar": [2, 1]}), + ], + [ + Document(page_content="foo", metadata={"bar": [1, 2]}), + Document(page_content="foo", metadata={"bar": [2, 1]}), + ], + ), + ], +) +def test__unique_documents(documents: List[Document], expected: List[Document]) -> None: + assert _unique_documents(documents) == expected