fixed integration tests

This commit is contained in:
Eric Pinzur 2024-11-07 11:46:30 -06:00
parent 14f1827953
commit 868f2f6932

View File

@ -1,12 +1,16 @@
"""Test Chroma functionality.""" """Test Chroma functionality."""
import uuid import uuid
from typing import Generator from typing import (
Generator,
cast,
)
import chromadb import chromadb
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
import requests import requests
from chromadb.api.client import SharedSystemClient from chromadb.api.client import SharedSystemClient
from chromadb.api.types import Embeddable
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings.fake import FakeEmbeddings as Fak from langchain_core.embeddings.fake import FakeEmbeddings as Fak
@ -17,6 +21,15 @@ from tests.integration_tests.fake_embeddings import (
) )
class MyEmbeddingFunction:
def __init__(self, fak: Fak):
self.fak = fak
def __call__(self, input: Embeddable) -> list[list[float]]:
texts = cast(list[str], input)
return self.fak.embed_documents(texts=texts)
@pytest.fixture() @pytest.fixture()
def client() -> Generator[chromadb.ClientAPI, None, None]: def client() -> Generator[chromadb.ClientAPI, None, None]:
SharedSystemClient.clear_system_cache() SharedSystemClient.clear_system_cache()
@ -254,8 +267,8 @@ def test_chroma_update_document() -> None:
# Assert that the updated document is returned by the search # Assert that the updated document is returned by the search
assert output == [Document(page_content=updated_content, metadata={"page": "0"})] assert output == [Document(page_content=updated_content, metadata={"page": "0"})]
assert new_embedding == embedding.embed_documents([updated_content])[0] assert list(new_embedding) == list(embedding.embed_documents([updated_content])[0])
assert new_embedding != old_embedding assert list(new_embedding) != list(old_embedding)
# TODO: RELEVANCE SCORE IS BROKEN. FIX TEST # TODO: RELEVANCE SCORE IS BROKEN. FIX TEST
@ -341,17 +354,17 @@ def batch_support_chroma_version() -> bool:
) )
def test_chroma_large_batch() -> None: def test_chroma_large_batch() -> None:
client = chromadb.HttpClient() client = chromadb.HttpClient()
embedding_function = Fak(size=255) embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection( col = client.get_or_create_collection(
"my_collection", "my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore embedding_function=embedding_function, # type: ignore
) )
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
db = Chroma.from_texts( db = Chroma.from_texts(
client=client, client=client,
collection_name=col.name, collection_name=col.name,
texts=docs, texts=docs,
embedding=embedding_function, embedding=embedding_function.fak,
ids=[str(uuid.uuid4()) for _ in range(len(docs))], ids=[str(uuid.uuid4()) for _ in range(len(docs))],
) )
@ -369,18 +382,18 @@ def test_chroma_large_batch() -> None:
) )
def test_chroma_large_batch_update() -> None: def test_chroma_large_batch_update() -> None:
client = chromadb.HttpClient() client = chromadb.HttpClient()
embedding_function = Fak(size=255) embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection( col = client.get_or_create_collection(
"my_collection", "my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore embedding_function=embedding_function, # type: ignore
) )
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
ids = [str(uuid.uuid4()) for _ in range(len(docs))] ids = [str(uuid.uuid4()) for _ in range(len(docs))]
db = Chroma.from_texts( db = Chroma.from_texts(
client=client, client=client,
collection_name=col.name, collection_name=col.name,
texts=docs, texts=docs,
embedding=embedding_function, embedding=embedding_function.fak,
ids=ids, ids=ids,
) )
new_docs = [ new_docs = [
@ -408,7 +421,7 @@ def test_chroma_legacy_batching() -> None:
embedding_function = Fak(size=255) embedding_function = Fak(size=255)
col = client.get_or_create_collection( col = client.get_or_create_collection(
"my_collection", "my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore embedding_function=MyEmbeddingFunction, # type: ignore
) )
docs = ["This is a test document"] * 100 docs = ["This is a test document"] * 100
db = Chroma.from_texts( db = Chroma.from_texts(