Fix: Nested Dicts Handling of Document Metadata (#9880)

## Description
When the `MultiQueryRetriever` is used to retrieve the documents in a
vector store that are relevant to a query, and at least one of them
contains metadata with nested dictionaries, a `TypeError: unhashable
type: 'dict'` exception is thrown.
This is caused by the `unique_union` function which, to guarantee the
uniqueness of the returned documents, tries, unsuccessfully, to hash the
nested dictionaries and use them as part of the key.
```python
unique_documents_dict = {
    (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
    for doc in documents
}
```

## Issue
#9872 (MultiQueryRetriever (get_relevant_documents) raises TypeError:
unhashable type: 'dict' with dic metadata)

## Solution
A possible solution is to dump the metadata dict to a string and use it
as part of the hashed key.
```python
unique_documents_dict = {
    (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc
    for doc in documents
}
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10155/head
Lorenzo 1 year ago committed by GitHub
parent a52fe9528e
commit 00a7c31ffd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
import logging
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.llm import LLMChain
@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
)
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
class MultiQueryRetriever(BaseRetriever):
"""Given a query, use an LLM to write a set of queries.
Retrieve docs for each query. Rake the unique union of all retrieved docs.
Retrieve docs for each query. Return the unique union of all retrieved docs.
"""
retriever: BaseRetriever
@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Get relevated documents given a user query.
"""Get relevant documents given a user query.
Args:
question: user query
@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
"""
queries = self.generate_queries(query, run_manager)
documents = self.retrieve_documents(queries, run_manager)
unique_documents = self.unique_union(documents)
return unique_documents
return self.unique_union(documents)
def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun
@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
List of unique retrieved Documents
"""
# Create a dictionary with page_content as keys to remove duplicates
# TODO: Add Document ID property (e.g., UUID)
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc
for doc in documents
}
unique_documents = list(unique_documents_dict.values())
return unique_documents
return _unique_documents(documents)

@ -0,0 +1,40 @@
from typing import List
import pytest as pytest
from langchain.retrievers.multi_query import _unique_documents
from langchain.schema import Document
@pytest.mark.parametrize(
    "documents,expected",
    [
        ([], []),
        ([Document(page_content="foo")], [Document(page_content="foo")]),
        ([Document(page_content="foo")] * 2, [Document(page_content="foo")]),
        (
            [Document(page_content="foo", metadata={"bar": "baz"})] * 2,
            [Document(page_content="foo", metadata={"bar": "baz"})],
        ),
        (
            [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
            [Document(page_content="foo", metadata={"bar": [1, 2]})],
        ),
        (
            [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
            [Document(page_content="foo", metadata={"bar": {1, 2}})],
        ),
        (
            [
                Document(page_content="foo", metadata={"bar": [1, 2]}),
                Document(page_content="foo", metadata={"bar": [2, 1]}),
            ],
            [
                Document(page_content="foo", metadata={"bar": [1, 2]}),
                Document(page_content="foo", metadata={"bar": [2, 1]}),
            ],
        ),
        # Nested-dict metadata is the case that used to raise
        # ``TypeError: unhashable type: 'dict'``. The two inputs are built
        # separately so the match relies on equality, not object identity.
        (
            [
                Document(page_content="foo", metadata={"bar": {"baz": 1}}),
                Document(page_content="foo", metadata={"bar": {"baz": 1}}),
            ],
            [Document(page_content="foo", metadata={"bar": {"baz": 1}})],
        ),
        # Documents whose nested metadata differs must both be kept.
        (
            [
                Document(page_content="foo", metadata={"bar": {"baz": 1}}),
                Document(page_content="foo", metadata={"bar": {"baz": 2}}),
            ],
            [
                Document(page_content="foo", metadata={"bar": {"baz": 1}}),
                Document(page_content="foo", metadata={"bar": {"baz": 2}}),
            ],
        ),
    ],
)
def test__unique_documents(documents: List[Document], expected: List[Document]) -> None:
    """Check that ``_unique_documents`` drops duplicates and preserves order.

    Covers empty input, plain documents, and metadata holding unhashable
    values (lists, sets, nested dicts).
    """
    assert _unique_documents(documents) == expected
Loading…
Cancel
Save