Add support for Qdrant nested filter (#4354)

# Add support for Qdrant nested filter

This extends the filter functionality of the Qdrant vectorstore. The
current filter implementation is limited to a single-level metadata
structure; however, Qdrant also supports filtering on nested metadata
keys. With this change, filter values may themselves be dictionaries or
lists (for example, `{"page": 1, "metadata": {"page": 2, "pages": [3]}}`),
so users can take full advantage of Qdrant's filtering when using it as
the vectorstore.

Reference: https://qdrant.tech/documentation/filtering/#nested-key

---------

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
parallel_dir_loader
Aivin V. Solatorio 1 year ago committed by GitHub
parent 872605a5c5
commit 6335cb5b3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,18 @@ import uuid
import warnings import warnings
from hashlib import md5 from hashlib import md5
from operator import itemgetter from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
)
import numpy as np import numpy as np
@ -14,7 +25,11 @@ from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance from langchain.vectorstores.utils import maximal_marginal_relevance
MetadataFilter = Dict[str, Union[str, int, bool]] if TYPE_CHECKING:
from qdrant_client.http import models as rest
MetadataFilter = Dict[str, Union[str, int, bool, dict, list]]
class Qdrant(VectorStore): class Qdrant(VectorStore):
@ -461,18 +476,42 @@ class Qdrant(VectorStore):
metadata=scored_point.payload.get(metadata_payload_key) or {}, metadata=scored_point.payload.get(metadata_payload_key) or {},
) )
def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any: def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
if filter is None or 0 == len(filter):
return None
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
return rest.Filter( out = []
must=[
if isinstance(value, dict):
for _key, value in value.items():
out.extend(self._build_condition(f"{key}.{_key}", value))
elif isinstance(value, list):
for _value in value:
if isinstance(_value, dict):
out.extend(self._build_condition(f"{key}[]", _value))
else:
out.extend(self._build_condition(f"{key}", _value))
else:
out.append(
rest.FieldCondition( rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}", key=f"{self.metadata_payload_key}.{key}",
match=rest.MatchValue(value=value), match=rest.MatchValue(value=value),
) )
)
return out
def _qdrant_filter_from_dict(
self, filter: Optional[MetadataFilter]
) -> Optional[rest.Filter]:
from qdrant_client.http import models as rest
if not filter:
return None
return rest.Filter(
must=[
condition
for key, value in filter.items() for key, value in filter.items()
for condition in self._build_condition(key, value)
] ]
) )

@ -78,15 +78,26 @@ def test_qdrant_with_metadatas(
def test_qdrant_similarity_search_filters() -> None: def test_qdrant_similarity_search_filters() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))] metadatas = [
{"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}}
for i in range(len(texts))
]
docsearch = Qdrant.from_texts( docsearch = Qdrant.from_texts(
texts, texts,
FakeEmbeddings(), FakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
location=":memory:", location=":memory:",
) )
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == [Document(page_content="bar", metadata={"page": 1})] output = docsearch.similarity_search(
"foo", k=1, filter={"page": 1, "metadata": {"page": 2, "pages": [3]}}
)
assert output == [
Document(
page_content="bar",
metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}},
)
]
@pytest.mark.parametrize( @pytest.mark.parametrize(

Loading…
Cancel
Save