Chroma batching (#11203)

- **Description:** Chroma >= 0.4.10 added support for batch sizes
validation of add/upsert. This batch size is dependent on the SQLite
limits of the target system and varies. In this change, for
Chroma>=0.4.10 batch splitting was added as the aforementioned
validation is starting to surface in the Chroma community (users using
LC)
 - **Issue:** N/A
 - **Dependencies:** N/A
 - **Tag maintainer:** @eyurtsev
 - **Twitter handle:** t_azarov
pull/11950/head
Trayan Azarov 12 months ago committed by GitHub
parent 9373b9c004
commit 1fd21ed21c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -558,12 +558,31 @@ class Chroma(VectorStore):
)
embeddings = self._embedding_function.embed_documents(text)
self._collection.update(
ids=ids,
embeddings=embeddings,
documents=text,
metadatas=metadata,
)
if hasattr(
self._collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches
for batch in create_batches(
api=self._collection._client,
ids=ids,
metadatas=metadata,
documents=text,
embeddings=embeddings,
):
self._collection.update(
ids=batch[0],
embeddings=batch[1],
documents=batch[3],
metadatas=batch[2],
)
else:
self._collection.update(
ids=ids,
embeddings=embeddings,
documents=text,
metadatas=metadata,
)
@classmethod
def from_texts(
@ -607,7 +626,24 @@ class Chroma(VectorStore):
collection_metadata=collection_metadata,
**kwargs,
)
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
if hasattr(
chroma_collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches
for batch in create_batches(
api=chroma_collection._client,
ids=ids,
metadatas=metadatas,
documents=texts,
):
chroma_collection.add_texts(
texts=batch[3] if batch[3] else [],
metadatas=batch[2] if batch[2] else None,
ids=batch[0],
)
else:
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return chroma_collection
@classmethod

@ -1,7 +1,11 @@
"""Test Chroma functionality."""
import uuid
import pytest
import requests
from langchain.docstore.document import Document
from langchain.embeddings import FakeEmbeddings as Fak
from langchain.vectorstores import Chroma
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
@ -301,3 +305,112 @@ def test_chroma_add_documents_mixed_metadata() -> None:
assert sorted(search, key=lambda d: d.page_content) == sorted(
docs, key=lambda d: d.page_content
)
def is_api_accessible(url: str) -> bool:
try:
response = requests.get(url)
return response.status_code == 200
except Exception:
return False
def batch_support_chroma_version() -> bool:
import chromadb
major, minor, patch = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) >= 4 and int(patch) >= 10:
return True
return False
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
reason="API not accessible",
)
@pytest.mark.skipif(
not batch_support_chroma_version(),
reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch() -> None:
import chromadb
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100)
Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
)
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
reason="API not accessible",
)
@pytest.mark.skipif(
not batch_support_chroma_version(),
reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch_update() -> None:
import chromadb
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100)
ids = [str(uuid.uuid4()) for _ in range(len(docs))]
db = Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
ids=ids,
)
new_docs = [
Document(
page_content="This is a new test document", metadata={"doc_id": f"{i}"}
)
for i in range(len(docs) - 10)
]
new_ids = [_id for _id in ids[: len(new_docs)]]
db.update_documents(ids=new_ids, documents=new_docs)
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
reason="API not accessible",
)
@pytest.mark.skipif(
batch_support_chroma_version(), reason="ChromaDB version does not support batching"
)
def test_chroma_legacy_batching() -> None:
import chromadb
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
)
docs = ["This is a test document"] * 100
Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
)

Loading…
Cancel
Save