diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 2f2b49db..2a0db589 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -130,13 +130,14 @@ class Chroma(VectorStore): Returns: List[Document]: List of documents most simmilar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, k) + docs_and_scores = self.similarity_search_with_score(query, k, where=filter) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -146,7 +147,9 @@ class Chroma(VectorStore): Returns: List of Documents most similar to the query vector. """ - results = self._collection.query(query_embeddings=embedding, n_results=k) + results = self._collection.query( + query_embeddings=embedding, n_results=k, where=filter + ) return _results_to_docs(results) def similarity_search_with_score( diff --git a/tests/integration_tests/vectorstores/fake_embeddings.py b/tests/integration_tests/vectorstores/fake_embeddings.py index e5a5fcd3..17a81e04 100644 --- a/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/tests/integration_tests/vectorstores/fake_embeddings.py @@ -10,9 +10,13 @@ class FakeEmbeddings(Embeddings): """Fake embeddings functionality for testing.""" def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return simple embeddings.""" + """Return simple embeddings. + Embeddings encode each text as its index.""" return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] def embed_query(self, text: str) -> List[float]: - """Return simple embeddings.""" + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents.""" return [float(1.0)] * 9 + [float(0.0)] diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py index e558a074..9cb07599 100644 --- a/tests/integration_tests/vectorstores/test_chroma.py +++ b/tests/integration_tests/vectorstores/test_chroma.py @@ -29,7 +29,7 @@ def test_chroma_with_metadatas() -> None: def test_chroma_with_metadatas_with_scores() -> None: - """Test end to end construction and search.""" + """Test end to end construction and scored search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = Chroma.from_texts( @@ -39,7 +39,47 @@ def test_chroma_with_metadatas_with_scores() -> None: metadatas=metadatas, ) output = docsearch.similarity_search_with_score("foo", k=1) - assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + +def test_chroma_search_filter() -> None: + """Test end to end construction and search with metadata filtering.""" + texts = ["far", "bar", "baz"] + metadatas = [{"first_letter": "{}".format(text[0])} for text in texts] + docsearch = Chroma.from_texts( + collection_name="test_collection", + texts=texts, + embedding=FakeEmbeddings(), + metadatas=metadatas, + ) + output = docsearch.similarity_search("far", k=1, filter={"first_letter": "f"}) + assert output == [Document(page_content="far", metadata={"first_letter": "f"})] + output = docsearch.similarity_search("far", k=1, filter={"first_letter": "b"}) + assert output == [Document(page_content="bar", metadata={"first_letter": "b"})] + + +def test_chroma_search_filter_with_scores() -> None: + """Test end to end construction and scored search with metadata filtering.""" + texts = ["far", "bar", "baz"] + metadatas = [{"first_letter": "{}".format(text[0])} for text in texts] + docsearch = Chroma.from_texts( + collection_name="test_collection", + texts=texts, + embedding=FakeEmbeddings(), + metadatas=metadatas, + ) + output = docsearch.similarity_search_with_score( + "far", k=1, filter={"first_letter": "f"} + ) + assert output == [ + (Document(page_content="far", metadata={"first_letter": "f"}), 0.0) + ] + output = docsearch.similarity_search_with_score( + "far", k=1, filter={"first_letter": "b"} + ) + assert output == [ + (Document(page_content="bar", metadata={"first_letter": "b"}), 1.0) + ] def test_chroma_with_persistence() -> None: