Feat: Add batching to Qdrant (#5443)

# Add batching to Qdrant

Several people requested a batching mechanism for uploading data to
Qdrant. This matters because the request payload has a maximum size,
and without batching implemented in LangChain, users have to implement
it on their own. This PR exposes a new optional `batch_size` parameter,
so all the documents/texts are uploaded in batches of the requested
size (64 by default).

The integration tests of Qdrant are extended to cover two cases:
1. Documents are sent in separate batches.
2. All the documents are sent in a single request.
This commit is contained in:
Kacper Łukawski 2023-05-31 00:33:54 +02:00 committed by GitHub
parent 80e133f16d
commit f93d256190
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 121 additions and 47 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import uuid
import warnings
from hashlib import md5
from itertools import islice
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@ -158,6 +159,7 @@ class Qdrant(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
@ -171,24 +173,30 @@ class Qdrant(VectorStore):
"""
from qdrant_client.http import models as rest
texts = list(
texts
) # otherwise iterable might be exhausted after id calculation
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
ids = []
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
self.client.upsert(
collection_name=self.collection_name,
points=rest.Batch.construct(
ids=ids,
vectors=self._embed_texts(texts),
payloads=self._build_payloads(
texts,
metadatas,
self.content_payload_key,
self.metadata_payload_key,
batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]
self.client.upsert(
collection_name=self.collection_name,
points=rest.Batch.construct(
ids=batch_ids,
vectors=self._embed_texts(batch_texts),
payloads=self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
),
),
)
)
ids.extend(batch_ids)
return ids
@ -309,6 +317,7 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
batch_size: int = 64,
**kwargs: Any,
) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts.
@ -361,7 +370,7 @@ class Qdrant(VectorStore):
**kwargs:
Additional arguments passed directly into REST client initialization
This is a user friendly interface that:
This is a user-friendly interface that:
1. Creates embeddings, one for each text
2. Initializes the Qdrant database as an in-memory docstore by default
(and overridable to a remote docstore)
@ -417,19 +426,28 @@ class Qdrant(VectorStore):
),
)
# Now generate the embeddings for all the texts
embeddings = embedding.embed_documents(texts)
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in texts],
vectors=embeddings,
payloads=cls._build_payloads(
texts, metadatas, content_payload_key, metadata_payload_key
# Generate the embeddings for all the texts in a batch
batch_embeddings = embedding.embed_documents(batch_texts)
client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
vectors=batch_embeddings,
payloads=cls._build_payloads(
batch_texts,
batch_metadatas,
content_payload_key,
metadata_payload_key,
),
),
),
)
)
return cls(
client=client,

View File

@ -20,3 +20,28 @@ class FakeEmbeddings(Embeddings):
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]
class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embedding model that gives every distinct text a stable vector.

    Texts are remembered in the order they are first seen. The vector of a
    text is nine ones followed by the position of that text in the list of
    known texts, so embedding the same text twice yields the same vector.
    """

    def __init__(self) -> None:
        # Texts in first-seen order; a text's index is the last vector component.
        self.known_texts: List[str] = []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text, registering any text that was not seen before."""
        vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            position = self.known_texts.index(text)
            vectors.append([1.0] * 9 + [float(position)])
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a query without registering it.

        A known text gets the same vector `embed_documents` produced for it;
        an unknown text gets the constant vector ending in 0.0.
        """
        try:
            position = self.known_texts.index(text)
        except ValueError:
            # Unknown query: fall back to the constant vector (last component 0.0).
            position = 0
        return [1.0] * 9 + [float(position)]

View File

@ -6,9 +6,12 @@ import pytest
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
@ -18,36 +21,59 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
("foo", Qdrant.METADATA_KEY),
],
)
def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None:
def test_qdrant_similarity_search(
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_qdrant_add_documents() -> None:
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_documents(batch_size: int) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:")
docsearch: Qdrant = Qdrant.from_texts(
texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size
)
new_texts = ["foobar", "foobaz"]
docsearch.add_documents([Document(page_content=content) for content in new_texts])
docsearch.add_documents(
[Document(page_content=content) for content in new_texts], batch_size=batch_size
)
output = docsearch.similarity_search("foobar", k=1)
# FakeEmbeddings return the same query embedding as the first document embedding
# computed in `embedding.embed_documents`. Since embed_documents is called twice,
# "foo" embedding is the same as "foobar" embedding
# ConsistentFakeEmbeddings return, for a text seen before, the same query embedding
# as the document embedding computed in `embedding.embed_documents`. Thus, "foobar"
# is expected to match itself; "foo" (the first text seen) is accepted as well
assert output == [Document(page_content="foobar")] or output == [
Document(page_content="foo")
]
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
    """Check that `add_texts` returns one unique id per added text.

    The test runs once with `batch_size=1` (each text goes into its own batch)
    and once with `batch_size=64` (all texts fit into a single batch).
    """
    docsearch: Qdrant = Qdrant.from_texts(
        ["foobar"],
        ConsistentFakeEmbeddings(),
        location=":memory:",
        batch_size=batch_size,
    )

    # Forward batch_size to add_texts as well: without it the default of 64 is
    # used and the three texts always land in one batch, so collecting ids
    # across several batches would never be exercised.
    ids = docsearch.add_texts(["foo", "bar", "baz"], batch_size=batch_size)
    assert 3 == len(ids)
    assert 3 == len(set(ids))
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
@ -58,24 +84,26 @@ def test_qdrant_add_documents() -> None:
],
)
def test_qdrant_with_metadatas(
content_payload_key: str, metadata_payload_key: str
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_qdrant_similarity_search_filters() -> None:
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_similarity_search_filters(batch_size: int) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [
@ -84,9 +112,10 @@ def test_qdrant_similarity_search_filters() -> None:
]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
batch_size=batch_size,
)
output = docsearch.similarity_search(
@ -100,6 +129,7 @@ def test_qdrant_similarity_search_filters() -> None:
]
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
@ -110,18 +140,19 @@ def test_qdrant_similarity_search_filters() -> None:
],
)
def test_qdrant_max_marginal_relevance_search(
content_payload_key: str, metadata_payload_key: str
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
@ -133,9 +164,9 @@ def test_qdrant_max_marginal_relevance_search(
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), None),
(FakeEmbeddings().embed_query, None),
(None, FakeEmbeddings().embed_query),
(ConsistentFakeEmbeddings(), None),
(ConsistentFakeEmbeddings().embed_query, None),
(None, ConsistentFakeEmbeddings().embed_query),
],
)
def test_qdrant_embedding_interface(
@ -157,7 +188,7 @@ def test_qdrant_embedding_interface(
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), FakeEmbeddings().embed_query),
(ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query),
(None, None),
],
)