From 4a327dd1d602163df58f8807b67a3dafc29c7b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Wed, 15 Mar 2023 15:31:39 +0100 Subject: [PATCH] Implement basic metadata filtering in Qdrant (#1689) This PR implements a basic metadata filtering mechanism similar to the ones in Chroma and Pinecone. It cannot yet express complex conditions because no operators are supported, but some users have requested this feature. --- langchain/vectorstores/qdrant.py | 33 ++++++++++++++++--- .../vectorstores/test_qdrant.py | 14 ++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index b9424c79..4ce535fc 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -1,13 +1,15 @@ """Wrapper around Qdrant vector database.""" import uuid from operator import itemgetter -from typing import Any, Callable, Iterable, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance +MetadataFilter = Dict[str, Union[str, int, bool]] + class Qdrant(VectorStore): """Wrapper around Qdrant vector database. @@ -91,28 +93,34 @@ class Qdrant(VectorStore): return ids def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + filter: Optional[MetadataFilter] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to query. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. + filter: Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query. 
""" - results = self.similarity_search_with_score(query, k) + results = self.similarity_search_with_score(query, k, filter) return list(map(itemgetter(0), results)) def similarity_search_with_score( - self, query: str, k: int = 4 + self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None ) -> List[Tuple[Document, float]]: """Return docs most similar to query. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. + filter: Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query and score for each @@ -121,6 +129,7 @@ class Qdrant(VectorStore): results = self.client.search( collection_name=self.collection_name, query_vector=embedding, + query_filter=self._qdrant_filter_from_dict(filter), with_payload=True, limit=k, ) @@ -380,3 +389,19 @@ class Qdrant(VectorStore): page_content=scored_point.payload.get(content_payload_key), metadata=scored_point.payload.get(metadata_payload_key) or {}, ) + + def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any: + if filter is None or 0 == len(filter): + return None + + from qdrant_client.http import models as rest + + return rest.Filter( + must=[ + rest.FieldCondition( + key=f"{self.metadata_payload_key}.{key}", + match=rest.MatchValue(value=value), + ) + for key, value in filter.items() + ] + ) diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 0e249b22..fe9498b7 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -56,6 +56,20 @@ def test_qdrant_with_metadatas( assert output == [Document(page_content="foo", metadata={"page": 0})] +def test_qdrant_similarity_search_filters() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = Qdrant.from_texts( + texts, + 
FakeEmbeddings(), + metadatas=metadatas, + host="localhost", + ) + output = docsearch.similarity_search("foo", k=1, filter={"page": 1}) + assert output == [Document(page_content="bar", metadata={"page": 1})] + + @pytest.mark.parametrize( ["content_payload_key", "metadata_payload_key"], [