Propagate "filter" arg in Chroma similarity_search (#1869)

Technically a duplicate fix to #1619 but with unit tests and a small
documentation update
- Propagate `filter` arg in Chroma `similarity_search` to delegated call
to `similarity_search_with_score`
- Add `filter` arg to `similarity_search_by_vector`
- Clarify doc strings on FakeEmbeddings
This commit is contained in:
Eli 2023-03-22 19:40:10 -07:00 committed by GitHub
parent 31f9ecfc19
commit 12f868b292
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 6 deletions

View File

@ -130,13 +130,14 @@ class Chroma(VectorStore):
Returns: Returns:
List[Document]: List of documents most simmilar to the query text. List[Document]: List of documents most simmilar to the query text.
""" """
docs_and_scores = self.similarity_search_with_score(query, k) docs_and_scores = self.similarity_search_with_score(query, k, where=filter)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def similarity_search_by_vector( def similarity_search_by_vector(
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector.
@ -146,7 +147,9 @@ class Chroma(VectorStore):
Returns: Returns:
List of Documents most similar to the query vector. List of Documents most similar to the query vector.
""" """
results = self._collection.query(query_embeddings=embedding, n_results=k) results = self._collection.query(
query_embeddings=embedding, n_results=k, where=filter
)
return _results_to_docs(results) return _results_to_docs(results)
def similarity_search_with_score( def similarity_search_with_score(

View File

@ -10,9 +10,13 @@ class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing.""" """Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.""" """Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Return simple embeddings.""" """Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)] return [float(1.0)] * 9 + [float(0.0)]

View File

@ -29,7 +29,7 @@ def test_chroma_with_metadatas() -> None:
def test_chroma_with_metadatas_with_scores() -> None: def test_chroma_with_metadatas_with_scores() -> None:
"""Test end to end construction and search.""" """Test end to end construction and scored search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))] metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = Chroma.from_texts( docsearch = Chroma.from_texts(
@ -39,7 +39,47 @@ def test_chroma_with_metadatas_with_scores() -> None:
metadatas=metadatas, metadatas=metadatas,
) )
output = docsearch.similarity_search_with_score("foo", k=1) output = docsearch.similarity_search_with_score("foo", k=1)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
def test_chroma_search_filter() -> None:
"""Test end to end construction and search with metadata filtering."""
texts = ["far", "bar", "baz"]
metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
docsearch = Chroma.from_texts(
collection_name="test_collection",
texts=texts,
embedding=FakeEmbeddings(),
metadatas=metadatas,
)
output = docsearch.similarity_search("far", k=1, filter={"first_letter": "f"})
assert output == [Document(page_content="far", metadata={"first_letter": "f"})]
output = docsearch.similarity_search("far", k=1, filter={"first_letter": "b"})
assert output == [Document(page_content="bar", metadata={"first_letter": "b"})]
def test_chroma_search_filter_with_scores() -> None:
"""Test end to end construction and scored search with metadata filtering."""
texts = ["far", "bar", "baz"]
metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
docsearch = Chroma.from_texts(
collection_name="test_collection",
texts=texts,
embedding=FakeEmbeddings(),
metadatas=metadatas,
)
output = docsearch.similarity_search_with_score(
"far", k=1, filter={"first_letter": "f"}
)
assert output == [
(Document(page_content="far", metadata={"first_letter": "f"}), 0.0)
]
output = docsearch.similarity_search_with_score(
"far", k=1, filter={"first_letter": "b"}
)
assert output == [
(Document(page_content="bar", metadata={"first_letter": "b"}), 1.0)
]
def test_chroma_with_persistence() -> None: def test_chroma_with_persistence() -> None: