Mirror of https://github.com/hwchase17/langchain, synced 2024-11-13 19:10:52 +00:00
fixed integration tests
parent 14f1827953
commit 868f2f6932
@@ -1,12 +1,16 @@
 """Test Chroma functionality."""

 import uuid

-from typing import Generator
+from typing import (
+    Generator,
+    cast,
+)

 import chromadb
 import pytest  # type: ignore[import-not-found]
 import requests
 from chromadb.api.client import SharedSystemClient
+from chromadb.api.types import Embeddable
 from langchain_core.documents import Document
 from langchain_core.embeddings.fake import FakeEmbeddings as Fak
@@ -17,6 +21,15 @@ from tests.integration_tests.fake_embeddings import (
 )


+class MyEmbeddingFunction:
+    def __init__(self, fak: Fak):
+        self.fak = fak
+
+    def __call__(self, input: Embeddable) -> list[list[float]]:
+        texts = cast(list[str], input)
+        return self.fak.embed_documents(texts=texts)
+
+
 @pytest.fixture()
 def client() -> Generator[chromadb.ClientAPI, None, None]:
     SharedSystemClient.clear_system_cache()
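The wrapper added here adapts the LangChain fake embeddings to the callable interface chromadb expects: a __call__ that receives the raw texts as Embeddable input and returns one vector per text. A minimal sketch of exercising it directly, assuming the class above is in scope:

from langchain_core.embeddings.fake import FakeEmbeddings as Fak

ef = MyEmbeddingFunction(fak=Fak(size=255))
vectors = ef(["first test document", "second test document"])  # what Chroma would call internally
assert len(vectors) == 2        # one embedding per input text
assert len(vectors[0]) == 255   # dimensionality follows Fak(size=255)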
@@ -254,8 +267,8 @@ def test_chroma_update_document() -> None:
     # Assert that the updated document is returned by the search
     assert output == [Document(page_content=updated_content, metadata={"page": "0"})]

-    assert new_embedding == embedding.embed_documents([updated_content])[0]
-    assert new_embedding != old_embedding
+    assert list(new_embedding) == list(embedding.embed_documents([updated_content])[0])
+    assert list(new_embedding) != list(old_embedding)


 # TODO: RELEVANCE SCORE IS BROKEN. FIX TEST
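Casting both sides to list makes the embedding comparison unambiguous: if the stored embedding comes back as a NumPy array (as newer chromadb versions may return), a bare == inside assert fails. A small illustration of that pitfall, assuming numpy is installed:

import numpy as np

stored = np.array([0.1, 0.2, 0.3])   # e.g. an embedding returned as an ndarray
expected = [0.1, 0.2, 0.3]

# `stored == expected` is an element-wise boolean array, so `assert stored == expected`
# raises "truth value of an array is ambiguous"; casting to list compares the values.
assert list(stored) == expected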
@@ -341,17 +354,17 @@ def batch_support_chroma_version() -> bool:
 )
 def test_chroma_large_batch() -> None:
     client = chromadb.HttpClient()
-    embedding_function = Fak(size=255)
+    embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
     col = client.get_or_create_collection(
         "my_collection",
-        embedding_function=embedding_function.embed_documents,  # type: ignore
+        embedding_function=embedding_function,  # type: ignore
     )
-    docs = ["This is a test document"] * (client.max_batch_size + 100)  # type: ignore
+    docs = ["This is a test document"] * (client.get_max_batch_size() + 100)  # type: ignore
     db = Chroma.from_texts(
         client=client,
         collection_name=col.name,
         texts=docs,
-        embedding=embedding_function,
+        embedding=embedding_function.fak,
         ids=[str(uuid.uuid4()) for _ in range(len(docs))],
     )

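The batch-size lookup changes from the client.max_batch_size attribute to the client.get_max_batch_size() method. A hedged compatibility sketch (a hypothetical helper, not part of this commit), assuming older chromadb clients expose only the attribute and newer ones the method:

def max_batch_size(client) -> int:
    # Prefer the newer accessor method; fall back to the legacy attribute.
    getter = getattr(client, "get_max_batch_size", None)
    if callable(getter):
        return getter()
    return client.max_batch_size  # older chromadb clients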
@@ -369,18 +382,18 @@ def test_chroma_large_batch() -> None:
 )
 def test_chroma_large_batch_update() -> None:
     client = chromadb.HttpClient()
-    embedding_function = Fak(size=255)
+    embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
     col = client.get_or_create_collection(
         "my_collection",
-        embedding_function=embedding_function.embed_documents,  # type: ignore
+        embedding_function=embedding_function,  # type: ignore
     )
-    docs = ["This is a test document"] * (client.max_batch_size + 100)  # type: ignore
+    docs = ["This is a test document"] * (client.get_max_batch_size() + 100)  # type: ignore
     ids = [str(uuid.uuid4()) for _ in range(len(docs))]
     db = Chroma.from_texts(
         client=client,
         collection_name=col.name,
         texts=docs,
-        embedding=embedding_function,
+        embedding=embedding_function.fak,
         ids=ids,
     )
     new_docs = [
@@ -408,7 +421,7 @@ def test_chroma_legacy_batching() -> None:
     embedding_function = Fak(size=255)
     col = client.get_or_create_collection(
         "my_collection",
-        embedding_function=embedding_function.embed_documents,  # type: ignore
+        embedding_function=MyEmbeddingFunction,  # type: ignore
     )
     docs = ["This is a test document"] * 100
     db = Chroma.from_texts(
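Taken together, the hunks keep two distinct hooks apart: chromadb collections call a __call__(input)-style embedding function, while Chroma.from_texts takes a langchain_core Embeddings object. A short sketch of that split, assuming an in-memory chromadb client, the wrapper class from this diff in scope, and the Chroma vectorstore imported as elsewhere in this test module (e.g. from langchain_chroma):

import uuid

import chromadb
from langchain_chroma import Chroma  # assumption: mirrors this test module's import
from langchain_core.embeddings.fake import FakeEmbeddings as Fak

fak = Fak(size=255)
client = chromadb.Client()  # in-memory client instead of HttpClient for the sketch

# chromadb wants a callable that maps raw texts to vectors (the wrapper above):
col = client.get_or_create_collection(
    "sketch_collection", embedding_function=MyEmbeddingFunction(fak=fak)
)

# LangChain's Chroma wrapper wants the Embeddings object itself:
db = Chroma.from_texts(
    client=client,
    collection_name=col.name,
    texts=["This is a test document"],
    embedding=fak,
    ids=[str(uuid.uuid4())],
)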