fixed integration tests

This commit is contained in:
Eric Pinzur 2024-11-07 11:46:30 -06:00
parent 14f1827953
commit 868f2f6932

View File

@ -1,12 +1,16 @@
"""Test Chroma functionality."""
import uuid
from typing import Generator
from typing import (
Generator,
cast,
)
import chromadb
import pytest # type: ignore[import-not-found]
import requests
from chromadb.api.client import SharedSystemClient
from chromadb.api.types import Embeddable
from langchain_core.documents import Document
from langchain_core.embeddings.fake import FakeEmbeddings as Fak
@ -17,6 +21,15 @@ from tests.integration_tests.fake_embeddings import (
)
class MyEmbeddingFunction:
def __init__(self, fak: Fak):
self.fak = fak
def __call__(self, input: Embeddable) -> list[list[float]]:
texts = cast(list[str], input)
return self.fak.embed_documents(texts=texts)
@pytest.fixture()
def client() -> Generator[chromadb.ClientAPI, None, None]:
SharedSystemClient.clear_system_cache()
@ -254,8 +267,8 @@ def test_chroma_update_document() -> None:
# Assert that the updated document is returned by the search
assert output == [Document(page_content=updated_content, metadata={"page": "0"})]
assert new_embedding == embedding.embed_documents([updated_content])[0]
assert new_embedding != old_embedding
assert list(new_embedding) == list(embedding.embed_documents([updated_content])[0])
assert list(new_embedding) != list(old_embedding)
# TODO: RELEVANCE SCORE IS BROKEN. FIX TEST
@ -341,17 +354,17 @@ def batch_support_chroma_version() -> bool:
)
def test_chroma_large_batch() -> None:
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=embedding_function, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore
docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
db = Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
embedding=embedding_function.fak,
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
)
@ -369,18 +382,18 @@ def test_chroma_large_batch() -> None:
)
def test_chroma_large_batch_update() -> None:
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=embedding_function, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore
docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
ids = [str(uuid.uuid4()) for _ in range(len(docs))]
db = Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
embedding=embedding_function.fak,
ids=ids,
)
new_docs = [
@ -408,7 +421,7 @@ def test_chroma_legacy_batching() -> None:
embedding_function = Fak(size=255)
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=MyEmbeddingFunction, # type: ignore
)
docs = ["This is a test document"] * 100
db = Chroma.from_texts(