diff --git a/libs/community/langchain_community/vectorstores/astradb.py b/libs/community/langchain_community/vectorstores/astradb.py index efb544f9af..1c71d3f7b8 100644 --- a/libs/community/langchain_community/vectorstores/astradb.py +++ b/libs/community/langchain_community/vectorstores/astradb.py @@ -121,11 +121,21 @@ class AstraDB(VectorStore): """ @staticmethod - def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]: + def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]: if filter_dict is None: return {} else: - return {f"metadata.{mdk}": mdv for mdk, mdv in filter_dict.items()} + metadata_filter = {} + for k, v in filter_dict.items(): + if k and k[0] == "$": + if isinstance(v, list): + metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v] + else: + metadata_filter[k] = AstraDB._filter_to_metadata(v) + else: + metadata_filter[f"metadata.{k}"] = v + + return metadata_filter def __init__( self, @@ -471,7 +481,7 @@ class AstraDB(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: """Return docs most similar to embedding vector. @@ -512,7 +522,7 @@ class AstraDB(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_id_by_vector( @@ -525,7 +535,7 @@ class AstraDB(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to embedding vector. @@ -548,7 +558,7 @@ class AstraDB(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: embedding_vector = self.embedding.embed_query(query) @@ -562,7 +572,7 @@ class AstraDB(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: return [ @@ -578,7 +588,7 @@ class AstraDB(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_by_vector( @@ -593,7 +603,7 @@ class AstraDB(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -650,7 +660,7 @@ class AstraDB(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. diff --git a/libs/community/tests/integration_tests/vectorstores/test_astradb.py b/libs/community/tests/integration_tests/vectorstores/test_astradb.py index c92737ea0c..4263b56161 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_astradb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_astradb.py @@ -354,6 +354,13 @@ class TestAstraDB: filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, ) assert res3 == [] + # filter with logical operator + res4 = store_someemb.similarity_search( + "x", + k=10, + filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, + ) + assert {doc.page_content for doc in res4} == {"q", "r"} def test_astradb_vectorstore_similarity_scale( self, store_parseremb: AstraDB