From 6335cb5b3a7b57a4a3dfa4b8070d7962aed789b1 Mon Sep 17 00:00:00 2001 From: "Aivin V. Solatorio" Date: Tue, 9 May 2023 13:34:11 -0400 Subject: [PATCH] Add support for Qdrant nested filter (#4354) # Add support for Qdrant nested filter This extends the filter functionality for the Qdrant vectorstore. The current filter implementation is limited to a single-level metadata structure; however, Qdrant supports nested metadata filtering. This extends the functionality for users to maximize the filter functionality when using Qdrant as the vectorstore. Reference: https://qdrant.tech/documentation/filtering/#nested-key --------- Signed-off-by: Aivin V. Solatorio --- langchain/vectorstores/qdrant.py | 55 ++++++++++++++++--- .../vectorstores/test_qdrant.py | 17 +++++- 2 files changed, 61 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 168da43a..5af38b96 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -5,7 +5,18 @@ import uuid import warnings from hashlib import md5 from operator import itemgetter -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) import numpy as np @@ -14,7 +25,11 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance -MetadataFilter = Dict[str, Union[str, int, bool]] +if TYPE_CHECKING: + from qdrant_client.http import models as rest + + +MetadataFilter = Dict[str, Union[str, int, bool, dict, list]] class Qdrant(VectorStore): @@ -461,18 +476,42 @@ class Qdrant(VectorStore): metadata=scored_point.payload.get(metadata_payload_key) or {}, ) - def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any: - if filter is None or 0 == len(filter): - return None - + def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: from qdrant_client.http import models as rest - return rest.Filter( - must=[ + out = [] + + if isinstance(value, dict): + for _key, value in value.items(): + out.extend(self._build_condition(f"{key}.{_key}", value)) + elif isinstance(value, list): + for _value in value: + if isinstance(_value, dict): + out.extend(self._build_condition(f"{key}[]", _value)) + else: + out.extend(self._build_condition(f"{key}", _value)) + else: + out.append( rest.FieldCondition( key=f"{self.metadata_payload_key}.{key}", match=rest.MatchValue(value=value), ) + ) + + return out + + def _qdrant_filter_from_dict( + self, filter: Optional[MetadataFilter] + ) -> Optional[rest.Filter]: + from qdrant_client.http import models as rest + + if not filter: + return None + + return rest.Filter( + must=[ + condition for key, value in filter.items() + for condition in self._build_condition(key, value) ] ) diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 1f43a0bc..8362951c 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -78,15 +78,26 @@ def test_qdrant_with_metadatas( def test_qdrant_similarity_search_filters() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] + metadatas = [ + {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} + for i in range(len(texts)) + ] docsearch = Qdrant.from_texts( texts, FakeEmbeddings(), metadatas=metadatas, location=":memory:", ) - output = docsearch.similarity_search("foo", k=1, filter={"page": 1}) - assert output == [Document(page_content="bar", metadata={"page": 1})] + + output = docsearch.similarity_search( + "foo", k=1, filter={"page": 1, "metadata": {"page": 2, "pages": [3]}} + ) + assert output == [ + Document( + page_content="bar", + metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}}, + ) + ] @pytest.mark.parametrize(