diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py
index b9424c79..4ce535fc 100644
--- a/langchain/vectorstores/qdrant.py
+++ b/langchain/vectorstores/qdrant.py
@@ -1,13 +1,15 @@
 """Wrapper around Qdrant vector database."""
 import uuid
 from operator import itemgetter
-from typing import Any, Callable, Iterable, List, Optional, Tuple, cast
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
 
+MetadataFilter = Dict[str, Union[str, int, bool]]
+
 
 class Qdrant(VectorStore):
     """Wrapper around Qdrant vector database.
@@ -91,28 +93,34 @@ class Qdrant(VectorStore):
         return ids
 
     def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[MetadataFilter] = None,
+        **kwargs: Any,
     ) -> List[Document]:
         """Return docs most similar to query.
 
         Args:
             query: Text to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
+            filter: Metadata key/value pairs to match exactly. Defaults to None.
 
         Returns:
             List of Documents most similar to the query.
         """
-        results = self.similarity_search_with_score(query, k)
+        results = self.similarity_search_with_score(query, k, filter)
         return list(map(itemgetter(0), results))
 
     def similarity_search_with_score(
-        self, query: str, k: int = 4
+        self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None
     ) -> List[Tuple[Document, float]]:
         """Return docs most similar to query.
 
         Args:
             query: Text to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
+            filter: Metadata key/value pairs to match exactly. Defaults to None.
 
         Returns:
             List of Documents most similar to the query and score for each
@@ -121,6 +129,7 @@ class Qdrant(VectorStore):
         results = self.client.search(
             collection_name=self.collection_name,
             query_vector=embedding,
+            query_filter=self._qdrant_filter_from_dict(filter),
             with_payload=True,
             limit=k,
         )
@@ -380,3 +389,29 @@ class Qdrant(VectorStore):
             page_content=scored_point.payload.get(content_payload_key),
             metadata=scored_point.payload.get(metadata_payload_key) or {},
         )
+
+    def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any:
+        """Build a Qdrant filter from a flat metadata dict.
+
+        Args:
+            filter: Metadata key/value pairs; each entry becomes one condition.
+                None or an empty dict means no filtering.
+
+        Returns:
+            A ``rest.Filter`` with one ``must`` condition per entry, or None
+            when there is nothing to filter on.
+        """
+        if not filter:
+            return None
+
+        from qdrant_client.http import models as rest
+
+        return rest.Filter(
+            must=[
+                rest.FieldCondition(
+                    key=f"{self.metadata_payload_key}.{key}",
+                    match=rest.MatchValue(value=value),
+                )
+                for key, value in filter.items()
+            ]
+        )
diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py
index 0e249b22..fe9498b7 100644
--- a/tests/integration_tests/vectorstores/test_qdrant.py
+++ b/tests/integration_tests/vectorstores/test_qdrant.py
@@ -56,6 +56,20 @@ def test_qdrant_with_metadatas(
     assert output == [Document(page_content="foo", metadata={"page": 0})]
 
 
+def test_qdrant_similarity_search_filters() -> None:
+    """Test that similarity search applies the metadata filter."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearch = Qdrant.from_texts(
+        texts,
+        FakeEmbeddings(),
+        metadatas=metadatas,
+        host="localhost",
+    )
+    output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
+    assert output == [Document(page_content="bar", metadata={"page": 1})]
+
+
 @pytest.mark.parametrize(
     ["content_payload_key", "metadata_payload_key"],
     [