Feat: Add batching to Qdrant (#5443)

# Add batching to Qdrant

Several people requested a batching mechanism for uploading data to
Qdrant. It is important because there are limits on the maximum size
of the request payload, and without batching implemented in LangChain,
users have to implement it on their own. This PR exposes a new optional
`batch_size` parameter, so all the documents/texts are uploaded in batches
of the expected size (64 by default).

The integration tests of Qdrant are extended to cover two cases:
1. Documents are sent in separate batches.
2. All the documents are sent in a single request.
searx_updates
Kacper Łukawski 1 year ago committed by GitHub
parent 80e133f16d
commit f93d256190
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@ from __future__ import annotations
import uuid import uuid
import warnings import warnings
from hashlib import md5 from hashlib import md5
from itertools import islice
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -158,6 +159,7 @@ class Qdrant(VectorStore):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
batch_size: int = 64,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore. """Run more texts through the embeddings and add to the vectorstore.
@ -171,24 +173,30 @@ class Qdrant(VectorStore):
""" """
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
texts = list( ids = []
texts texts_iterator = iter(texts)
) # otherwise iterable might be exhausted after id calculation metadatas_iterator = iter(metadatas or [])
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts] while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
self.client.upsert( batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
collection_name=self.collection_name,
points=rest.Batch.construct( batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]
ids=ids,
vectors=self._embed_texts(texts), self.client.upsert(
payloads=self._build_payloads( collection_name=self.collection_name,
texts, points=rest.Batch.construct(
metadatas, ids=batch_ids,
self.content_payload_key, vectors=self._embed_texts(batch_texts),
self.metadata_payload_key, payloads=self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
), ),
), )
)
ids.extend(batch_ids)
return ids return ids
@ -309,6 +317,7 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine", distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY, content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
batch_size: int = 64,
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts. """Construct Qdrant wrapper from a list of texts.
@ -361,7 +370,7 @@ class Qdrant(VectorStore):
**kwargs: **kwargs:
Additional arguments passed directly into REST client initialization Additional arguments passed directly into REST client initialization
This is a user friendly interface that: This is a user-friendly interface that:
1. Creates embeddings, one for each text 1. Creates embeddings, one for each text
2. Initializes the Qdrant database as an in-memory docstore by default 2. Initializes the Qdrant database as an in-memory docstore by default
(and overridable to a remote docstore) (and overridable to a remote docstore)
@ -417,19 +426,28 @@ class Qdrant(VectorStore):
), ),
) )
# Now generate the embeddings for all the texts texts_iterator = iter(texts)
embeddings = embedding.embed_documents(texts) metadatas_iterator = iter(metadatas or [])
while batch_texts := list(islice(texts_iterator, batch_size)):
client.upsert( # Take the corresponding metadata for each text in a batch
collection_name=collection_name, batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in texts], # Generate the embeddings for all the texts in a batch
vectors=embeddings, batch_embeddings = embedding.embed_documents(batch_texts)
payloads=cls._build_payloads(
texts, metadatas, content_payload_key, metadata_payload_key client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
vectors=batch_embeddings,
payloads=cls._build_payloads(
batch_texts,
batch_metadatas,
content_payload_key,
metadata_payload_key,
),
), ),
), )
)
return cls( return cls(
client=client, client=client,

@ -20,3 +20,28 @@ class FakeEmbeddings(Embeddings):
Distance to each text will be that text's index, Distance to each text will be that text's index,
as it was passed to embed_documents.""" as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)] return [float(1.0)] * 9 + [float(0.0)]
class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings that return the same vector for a text seen before.

    Every text passed to ``embed_documents`` is remembered in order of first
    appearance. The position of a text in that history is used as the last
    component of its 10-element vector; the first nine components are 1.0.
    """

    def __init__(self) -> None:
        # Texts in order of first appearance; a text's list index is the
        # value stored in the last component of its vector.
        self.known_texts: List[str] = []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far.

        Texts not seen before are appended to ``known_texts`` first, so each
        one receives the next free position.
        """
        vectors: List[List[float]] = []
        for item in texts:
            if item not in self.known_texts:
                self.known_texts.append(item)
            position = self.known_texts.index(item)
            vectors.append([1.0] * 9 + [float(position)])
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a known text, or a constant one otherwise.

        Unlike ``embed_documents``, this does not record unknown texts; they
        get 0.0 as the last component, the same as the first known text.
        """
        try:
            position = self.known_texts.index(text)
        except ValueError:
            # Unknown text: fall back to the constant vector.
            position = 0
        return [1.0] * 9 + [float(position)]

@ -6,9 +6,12 @@ import pytest
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize( @pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"], ["content_payload_key", "metadata_payload_key"],
[ [
@ -18,36 +21,59 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
("foo", Qdrant.METADATA_KEY), ("foo", Qdrant.METADATA_KEY),
], ],
) )
def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None: def test_qdrant_similarity_search(
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
docsearch = Qdrant.from_texts( docsearch = Qdrant.from_texts(
texts, texts,
FakeEmbeddings(), ConsistentFakeEmbeddings(),
location=":memory:", location=":memory:",
content_payload_key=content_payload_key, content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
) )
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
def test_qdrant_add_documents() -> None: @pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_documents(batch_size: int) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:") docsearch: Qdrant = Qdrant.from_texts(
texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size
)
new_texts = ["foobar", "foobaz"] new_texts = ["foobar", "foobaz"]
docsearch.add_documents([Document(page_content=content) for content in new_texts]) docsearch.add_documents(
[Document(page_content=content) for content in new_texts], batch_size=batch_size
)
output = docsearch.similarity_search("foobar", k=1) output = docsearch.similarity_search("foobar", k=1)
# FakeEmbeddings return the same query embedding as the first document embedding # StatefulFakeEmbeddings return the same query embedding as the first document
# computed in `embedding.embed_documents`. Since embed_documents is called twice, # embedding computed in `embedding.embed_documents`. Thus, "foo" embedding is the
# "foo" embedding is the same as "foobar" embedding # same as "foobar" embedding
assert output == [Document(page_content="foobar")] or output == [ assert output == [Document(page_content="foobar")] or output == [
Document(page_content="foo") Document(page_content="foo")
] ]
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
    """Test that add_texts returns one unique id per text, across batches."""
    docsearch: Qdrant = Qdrant.from_texts(
        ["foobar"],
        ConsistentFakeEmbeddings(),
        location=":memory:",
        batch_size=batch_size,
    )

    # Forward batch_size so that batch_size=1 really splits the upload into
    # several upserts; without it add_texts uses its default of 64 and the
    # ids of all three texts come from a single batch in both parametrizations.
    ids = docsearch.add_texts(["foo", "bar", "baz"], batch_size=batch_size)
    assert 3 == len(ids)
    assert 3 == len(set(ids))
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize( @pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"], ["content_payload_key", "metadata_payload_key"],
[ [
@ -58,24 +84,26 @@ def test_qdrant_add_documents() -> None:
], ],
) )
def test_qdrant_with_metadatas( def test_qdrant_with_metadatas(
content_payload_key: str, metadata_payload_key: str batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))] metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts( docsearch = Qdrant.from_texts(
texts, texts,
FakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
location=":memory:", location=":memory:",
content_payload_key=content_payload_key, content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
) )
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})] assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_qdrant_similarity_search_filters() -> None: @pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_similarity_search_filters(batch_size: int) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [ metadatas = [
@ -84,9 +112,10 @@ def test_qdrant_similarity_search_filters() -> None:
] ]
docsearch = Qdrant.from_texts( docsearch = Qdrant.from_texts(
texts, texts,
FakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
location=":memory:", location=":memory:",
batch_size=batch_size,
) )
output = docsearch.similarity_search( output = docsearch.similarity_search(
@ -100,6 +129,7 @@ def test_qdrant_similarity_search_filters() -> None:
] ]
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize( @pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"], ["content_payload_key", "metadata_payload_key"],
[ [
@ -110,18 +140,19 @@ def test_qdrant_similarity_search_filters() -> None:
], ],
) )
def test_qdrant_max_marginal_relevance_search( def test_qdrant_max_marginal_relevance_search(
content_payload_key: str, metadata_payload_key: str batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None: ) -> None:
"""Test end to end construction and MRR search.""" """Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))] metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts( docsearch = Qdrant.from_texts(
texts, texts,
FakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
location=":memory:", location=":memory:",
content_payload_key=content_payload_key, content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
) )
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [ assert output == [
@ -133,9 +164,9 @@ def test_qdrant_max_marginal_relevance_search(
@pytest.mark.parametrize( @pytest.mark.parametrize(
["embeddings", "embedding_function"], ["embeddings", "embedding_function"],
[ [
(FakeEmbeddings(), None), (ConsistentFakeEmbeddings(), None),
(FakeEmbeddings().embed_query, None), (ConsistentFakeEmbeddings().embed_query, None),
(None, FakeEmbeddings().embed_query), (None, ConsistentFakeEmbeddings().embed_query),
], ],
) )
def test_qdrant_embedding_interface( def test_qdrant_embedding_interface(
@ -157,7 +188,7 @@ def test_qdrant_embedding_interface(
@pytest.mark.parametrize( @pytest.mark.parametrize(
["embeddings", "embedding_function"], ["embeddings", "embedding_function"],
[ [
(FakeEmbeddings(), FakeEmbeddings().embed_query), (ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query),
(None, None), (None, None),
], ],
) )

Loading…
Cancel
Save