forked from Archives/langchain
Feat: Add batching to Qdrant (#5443)
# Add batching to Qdrant Several people requested a batching mechanism while uploading data to Qdrant. It is important, as there are some limits for the maximum size of the request payload, and without batching implemented in Langchain, users need to implement it on their own. This PR exposes a new optional `batch_size` parameter, so all the documents/texts are loaded in batches of the expected size (64, by default). The integration tests of Qdrant are extended to cover two cases: 1. Documents are sent in separate batches. 2. All the documents are sent in a single request.
This commit is contained in:
parent
80e133f16d
commit
f93d256190
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
import warnings
|
||||
from hashlib import md5
|
||||
from itertools import islice
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -158,6 +159,7 @@ class Qdrant(VectorStore):
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
batch_size: int = 64,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
@ -171,25 +173,31 @@ class Qdrant(VectorStore):
|
||||
"""
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
texts = list(
|
||||
texts
|
||||
) # otherwise iterable might be exhausted after id calculation
|
||||
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
|
||||
ids = []
|
||||
texts_iterator = iter(texts)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
|
||||
batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=rest.Batch.construct(
|
||||
ids=ids,
|
||||
vectors=self._embed_texts(texts),
|
||||
ids=batch_ids,
|
||||
vectors=self._embed_texts(batch_texts),
|
||||
payloads=self._build_payloads(
|
||||
texts,
|
||||
metadatas,
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
ids.extend(batch_ids)
|
||||
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
@ -309,6 +317,7 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
batch_size: int = 64,
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
"""Construct Qdrant wrapper from a list of texts.
|
||||
@ -361,7 +370,7 @@ class Qdrant(VectorStore):
|
||||
**kwargs:
|
||||
Additional arguments passed directly into REST client initialization
|
||||
|
||||
This is a user friendly interface that:
|
||||
This is a user-friendly interface that:
|
||||
1. Creates embeddings, one for each text
|
||||
2. Initializes the Qdrant database as an in-memory docstore by default
|
||||
(and overridable to a remote docstore)
|
||||
@ -417,16 +426,25 @@ class Qdrant(VectorStore):
|
||||
),
|
||||
)
|
||||
|
||||
# Now generate the embeddings for all the texts
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
texts_iterator = iter(texts)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = embedding.embed_documents(batch_texts)
|
||||
|
||||
client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=rest.Batch.construct(
|
||||
ids=[md5(text.encode("utf-8")).hexdigest() for text in texts],
|
||||
vectors=embeddings,
|
||||
ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
|
||||
vectors=batch_embeddings,
|
||||
payloads=cls._build_payloads(
|
||||
texts, metadatas, content_payload_key, metadata_payload_key
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -20,3 +20,28 @@ class FakeEmbeddings(Embeddings):
|
||||
Distance to each text will be that text's index,
|
||||
as it was passed to embed_documents."""
|
||||
return [float(1.0)] * 9 + [float(0.0)]
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||
vectors for the same texts."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.known_texts: List[str] = []
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return consistent embeddings for each text seen so far."""
|
||||
out_vectors = []
|
||||
for text in texts:
|
||||
if text not in self.known_texts:
|
||||
self.known_texts.append(text)
|
||||
vector = [float(1.0)] * 9 + [float(self.known_texts.index(text))]
|
||||
out_vectors.append(vector)
|
||||
return out_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
if text not in self.known_texts:
|
||||
return [float(1.0)] * 9 + [float(0.0)]
|
||||
return [float(1.0)] * 9 + [float(self.known_texts.index(text))]
|
||||
|
@ -6,9 +6,12 @@ import pytest
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores import Qdrant
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
@ -18,36 +21,59 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
("foo", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None:
|
||||
def test_qdrant_similarity_search(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Qdrant.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
ConsistentFakeEmbeddings(),
|
||||
location=":memory:",
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_qdrant_add_documents() -> None:
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_add_documents(batch_size: int) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:")
|
||||
docsearch: Qdrant = Qdrant.from_texts(
|
||||
texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size
|
||||
)
|
||||
|
||||
new_texts = ["foobar", "foobaz"]
|
||||
docsearch.add_documents([Document(page_content=content) for content in new_texts])
|
||||
docsearch.add_documents(
|
||||
[Document(page_content=content) for content in new_texts], batch_size=batch_size
|
||||
)
|
||||
output = docsearch.similarity_search("foobar", k=1)
|
||||
# FakeEmbeddings return the same query embedding as the first document embedding
|
||||
# computed in `embedding.embed_documents`. Since embed_documents is called twice,
|
||||
# "foo" embedding is the same as "foobar" embedding
|
||||
# StatefulFakeEmbeddings return the same query embedding as the first document
|
||||
# embedding computed in `embedding.embed_documents`. Thus, "foo" embedding is the
|
||||
# same as "foobar" embedding
|
||||
assert output == [Document(page_content="foobar")] or output == [
|
||||
Document(page_content="foo")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
|
||||
docsearch: Qdrant = Qdrant.from_texts(
|
||||
["foobar"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
location=":memory:",
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
ids = docsearch.add_texts(["foo", "bar", "baz"])
|
||||
assert 3 == len(ids)
|
||||
assert 3 == len(set(ids))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
@ -58,24 +84,26 @@ def test_qdrant_add_documents() -> None:
|
||||
],
|
||||
)
|
||||
def test_qdrant_with_metadatas(
|
||||
content_payload_key: str, metadata_payload_key: str
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = Qdrant.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
location=":memory:",
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
|
||||
def test_qdrant_similarity_search_filters() -> None:
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_similarity_search_filters(batch_size: int) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [
|
||||
@ -84,9 +112,10 @@ def test_qdrant_similarity_search_filters() -> None:
|
||||
]
|
||||
docsearch = Qdrant.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
location=":memory:",
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search(
|
||||
@ -100,6 +129,7 @@ def test_qdrant_similarity_search_filters() -> None:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
@ -110,18 +140,19 @@ def test_qdrant_similarity_search_filters() -> None:
|
||||
],
|
||||
)
|
||||
def test_qdrant_max_marginal_relevance_search(
|
||||
content_payload_key: str, metadata_payload_key: str
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
"""Test end to end construction and MRR search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = Qdrant.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
location=":memory:",
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
|
||||
assert output == [
|
||||
@ -133,9 +164,9 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
@pytest.mark.parametrize(
|
||||
["embeddings", "embedding_function"],
|
||||
[
|
||||
(FakeEmbeddings(), None),
|
||||
(FakeEmbeddings().embed_query, None),
|
||||
(None, FakeEmbeddings().embed_query),
|
||||
(ConsistentFakeEmbeddings(), None),
|
||||
(ConsistentFakeEmbeddings().embed_query, None),
|
||||
(None, ConsistentFakeEmbeddings().embed_query),
|
||||
],
|
||||
)
|
||||
def test_qdrant_embedding_interface(
|
||||
@ -157,7 +188,7 @@ def test_qdrant_embedding_interface(
|
||||
@pytest.mark.parametrize(
|
||||
["embeddings", "embedding_function"],
|
||||
[
|
||||
(FakeEmbeddings(), FakeEmbeddings().embed_query),
|
||||
(ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user