Fix AstraDB logical operator filtering (#15699)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This change fixes the AstraDB logical operator filtering (`$and,`
`$or`).
The `metadata` prefix must not be added if the key is `$and` or `$or`.
pull/15723/head
Christophe Bornet 6 months ago committed by GitHub
parent 1f5f6381ec
commit a466f79ac9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -121,11 +121,21 @@ class AstraDB(VectorStore):
"""
@staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if filter_dict is None:
return {}
else:
return {f"metadata.{mdk}": mdv for mdk, mdv in filter_dict.items()}
metadata_filter = {}
for k, v in filter_dict.items():
if k and k[0] == "$":
if isinstance(v, list):
metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v]
else:
metadata_filter[k] = AstraDB._filter_to_metadata(v)
else:
metadata_filter[f"metadata.{k}"] = v
return metadata_filter
def __init__(
self,
@ -471,7 +481,7 @@ class AstraDB(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector.
@ -512,7 +522,7 @@ class AstraDB(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
@ -525,7 +535,7 @@ class AstraDB(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector.
@ -548,7 +558,7 @@ class AstraDB(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
embedding_vector = self.embedding.embed_query(query)
@ -562,7 +572,7 @@ class AstraDB(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [
@ -578,7 +588,7 @@ class AstraDB(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
@ -593,7 +603,7 @@ class AstraDB(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -650,7 +660,7 @@ class AstraDB(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.

@ -354,6 +354,13 @@ class TestAstraDB:
filter={"group": "consonant", "ord": ord("q"), "case": "upper"},
)
assert res3 == []
# filter with logical operator
res4 = store_someemb.similarity_search(
"x",
k=10,
filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]},
)
assert {doc.page_content for doc in res4} == {"q", "r"}
def test_astradb_vectorstore_similarity_scale(
self, store_parseremb: AstraDB

Loading…
Cancel
Save