diff --git a/docs/docs/integrations/vectorstores/faiss.ipynb b/docs/docs/integrations/vectorstores/faiss.ipynb index 9bfde3a925..54894e4c66 100644 --- a/docs/docs/integrations/vectorstores/faiss.ipynb +++ b/docs/docs/integrations/vectorstores/faiss.ipynb @@ -416,7 +416,7 @@ "metadata": {}, "source": [ "## Similarity Search with filtering\n", - "FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. You can filter the documents based on metadata. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:" + "FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. This filter is either a callble that takes as input a metadata dict and returns a bool, or a metadata dict where each missing key is ignored and each present k must be in a list of values. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:" ] }, { @@ -480,6 +480,8 @@ ], "source": [ "results_with_scores = db.similarity_search_with_score(\"foo\", filter=dict(page=1))\n", + "# Or with a callable:\n", + "# results_with_scores = db.similarity_search_with_score(\"foo\", filter=lambda d: d[\"page\"] == 1)\n", "for doc, score in results_with_scores:\n", " print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")" ] diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py index 8ca609b725..044209add2 100644 --- a/libs/community/langchain_community/vectorstores/faiss.py +++ b/libs/community/langchain_community/vectorstores/faiss.py @@ -273,7 +273,7 @@ class FAISS(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -282,7 +282,9 @@ class FAISS(VectorStore): Args: embedding: Embedding vector to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None. + filter (Optional[Union[Callable, Dict[str, Any]]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. **kwargs: kwargs to be passed to similarity search. Can include: @@ -299,6 +301,27 @@ class FAISS(VectorStore): faiss.normalize_L2(vector) scores, indices = self.index.search(vector, k if filter is None else fetch_k) docs = [] + + if filter is not None: + if isinstance(filter, dict): + + def filter_func(metadata): + if all( + metadata.get(key) in value + if isinstance(value, list) + else metadata.get(key) == value + for key, value in filter.items() + ): + return True + return False + elif callable(filter): + filter_func = filter + else: + raise ValueError( + "filter must be a dict of metadata or " + f"a callable, not {type(filter)}" + ) + for j, i in enumerate(indices[0]): if i == -1: # This happens when not enough docs are returned. @@ -307,13 +330,8 @@ class FAISS(VectorStore): doc = self.docstore.search(_id) if not isinstance(doc, Document): raise ValueError(f"Could not find document for id {_id}, got {doc}") - if filter is not None: - filter = { - key: [value] if not isinstance(value, list) else value - for key, value in filter.items() - } - if all(doc.metadata.get(key) in value for key, value in filter.items()): - docs.append((doc, scores[0][j])) + if filter is not None and filter_func(doc.metadata): + docs.append((doc, scores[0][j])) else: docs.append((doc, scores[0][j])) @@ -336,7 +354,7 @@ class FAISS(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -345,7 +363,10 @@ class FAISS(VectorStore): Args: embedding: Embedding vector to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None. + filter (Optional[Dict[str, Any]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. **kwargs: kwargs to be passed to similarity search. Can include: @@ -372,7 +393,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -381,7 +402,10 @@ class FAISS(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + filter (Optional[Dict[str, str]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. @@ -403,7 +427,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -412,7 +436,10 @@ class FAISS(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + filter (Optional[Dict[str, str]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. @@ -443,7 +470,10 @@ class FAISS(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + filter (Optional[Dict[str, str]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. @@ -463,7 +493,7 @@ class FAISS(VectorStore): self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Document]: @@ -472,7 +502,10 @@ class FAISS(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + filter (Optional[Dict[str, str]]): Filter by metadata. + Defaults to None. If a callable, it must take as input the + metadata dict of Document and return a bool. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. @@ -492,7 +525,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Document]: @@ -517,7 +550,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Document]: @@ -545,7 +578,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, ) -> List[Tuple[Document, float]]: """Return docs and their similarity scores selected using the maximal marginal relevance. @@ -572,6 +605,24 @@ class FAISS(VectorStore): ) if filter is not None: filtered_indices = [] + if isinstance(filter, dict): + + def filter_func(metadata): + if all( + metadata.get(key) in value + if isinstance(value, list) + else metadata.get(key) == value + for key, value in filter.items() + ): + return True + return False + elif callable(filter): + filter_func = filter + else: + raise ValueError( + "filter must be a dict of metadata or " + f"a callable, not {type(filter)}" + ) for i in indices[0]: if i == -1: # This happens when not enough docs are returned. @@ -580,12 +631,7 @@ class FAISS(VectorStore): doc = self.docstore.search(_id) if not isinstance(doc, Document): raise ValueError(f"Could not find document for id {_id}, got {doc}") - if all( - doc.metadata.get(key) in value - if isinstance(value, list) - else doc.metadata.get(key) == value - for key, value in filter.items() - ): + if filter_func(doc.metadata): filtered_indices.append(i) indices = np.array([filtered_indices]) # -1 happens when not enough docs are returned. @@ -617,7 +663,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, ) -> List[Tuple[Document, float]]: """Return docs and their similarity scores selected using the maximal marginal relevance asynchronously. @@ -655,7 +701,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -686,7 +732,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance asynchronously. @@ -719,7 +765,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -756,7 +802,7 @@ class FAISS(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance asynchronously. @@ -1110,7 +1156,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -1139,7 +1185,7 @@ class FAISS(VectorStore): self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Union[Callable, Dict[str, Any]]] = None, fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: diff --git a/libs/community/tests/unit_tests/vectorstores/test_faiss.py b/libs/community/tests/unit_tests/vectorstores/test_faiss.py index db8228962c..350e0a1a64 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_faiss.py +++ b/libs/community/tests/unit_tests/vectorstores/test_faiss.py @@ -307,6 +307,9 @@ def test_faiss_mmr_with_metadatas_and_filter() -> None: assert len(output) == 1 assert output[0][0] == Document(page_content="foo", metadata={"page": 1}) assert output[0][1] == 0.0 + assert output == docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1 + ) @pytest.mark.requires("faiss") @@ -321,6 +324,12 @@ async def test_faiss_async_mmr_with_metadatas_and_filter() -> None: assert len(output) == 1 assert output[0][0] == Document(page_content="foo", metadata={"page": 1}) assert output[0][1] == 0.0 + assert ( + output + == await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1 + ) + ) @pytest.mark.requires("faiss") @@ -336,6 +345,9 @@ def test_faiss_mmr_with_metadatas_and_list_filter() -> None: assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) assert output[0][1] == 0.0 assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + assert output == docsearch.max_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2] + ) @pytest.mark.requires("faiss") @@ -351,6 +363,11 @@ async def test_faiss_async_mmr_with_metadatas_and_list_filter() -> None: assert output[0][0] == Document(page_content="foo", metadata={"page": 0}) assert output[0][1] == 0.0 assert output[1][0] != Document(page_content="foo", metadata={"page": 0}) + assert output == ( + await docsearch.amax_marginal_relevance_search_with_score_by_vector( + query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2] + ) + ) @pytest.mark.requires("faiss") @@ -421,7 +438,11 @@ def test_faiss_with_metadatas_and_filter() -> None: ) assert docsearch.docstore.__dict__ == expected_docstore.__dict__ output = docsearch.similarity_search("foo", k=1, filter={"page": 1}) - assert output == [Document(page_content="bar", metadata={"page": 1})] + assert output == [Document(page_content="foo", metadata={"page": 0})] + assert output != [Document(page_content="bar", metadata={"page": 1})] + assert output == docsearch.similarity_search( + "foo", k=1, filter=lambda di: di["page"] == 1 + ) @pytest.mark.requires("faiss") @@ -444,7 +465,11 @@ async def test_faiss_async_with_metadatas_and_filter() -> None: ) assert docsearch.docstore.__dict__ == expected_docstore.__dict__ output = await docsearch.asimilarity_search("foo", k=1, filter={"page": 1}) - assert output == [Document(page_content="bar", metadata={"page": 1})] + assert output == [Document(page_content="foo", metadata={"page": 0})] + assert output != [Document(page_content="bar", metadata={"page": 1})] + assert output == await docsearch.asimilarity_search( + "foo", k=1, filter=lambda di: di["page"] == 1 + ) @pytest.mark.requires("faiss") @@ -474,6 +499,9 @@ def test_faiss_with_metadatas_and_list_filter() -> None: assert docsearch.docstore.__dict__ == expected_docstore.__dict__ output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]}) assert output == [Document(page_content="foo", metadata={"page": 0})] + assert output == docsearch.similarity_search( + "foor", k=1, filter=lambda di: di["page"] in [0, 1, 2] + ) @pytest.mark.requires("faiss") @@ -503,6 +531,9 @@ async def test_faiss_async_with_metadatas_and_list_filter() -> None: assert docsearch.docstore.__dict__ == expected_docstore.__dict__ output = await docsearch.asimilarity_search("foor", k=1, filter={"page": [0, 1, 2]}) assert output == [Document(page_content="foo", metadata={"page": 0})] + assert output == await docsearch.asimilarity_search( + "foor", k=1, filter=lambda di: di["page"] in [0, 1, 2] + ) @pytest.mark.requires("faiss")