forked from Archives/langchain
f93d256190
# Add batching to Qdrant Several people requested a batching mechanism while uploading data to Qdrant. It is important, as there are some limits for the maximum size of the request payload, and without batching implemented in Langchain, users need to implement it on their own. This PR exposes a new optional `batch_size` parameter, so all the documents/texts are loaded in batches of the expected size (64, by default). The integration tests of Qdrant are extended to cover two cases: 1. Documents are sent in separate batches. 2. All the documents are sent in a single request.
210 lines
6.4 KiB
Python
210 lines
6.4 KiB
Python
"""Test Qdrant functionality."""
|
|
from typing import Callable, Optional
|
|
|
|
import pytest
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.vectorstores import Qdrant
|
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
|
ConsistentFakeEmbeddings,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
    ["content_payload_key", "metadata_payload_key"],
    [
        (Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
        ("foo", "bar"),
        (Qdrant.CONTENT_KEY, "bar"),
        ("foo", Qdrant.METADATA_KEY),
    ],
)
def test_qdrant_similarity_search(
    batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
    """Build a ``:memory:`` store from texts and fetch the closest match.

    Runs once per combination of default/custom payload keys and per
    parametrized ``batch_size`` (1 and 64 for three texts).
    """
    store = Qdrant.from_texts(
        ["foo", "bar", "baz"],
        ConsistentFakeEmbeddings(),
        location=":memory:",
        content_payload_key=content_payload_key,
        metadata_payload_key=metadata_payload_key,
        batch_size=batch_size,
    )

    # No metadatas were supplied, so the expected document carries only
    # its page content.
    expected = [Document(page_content="foo")]
    hits = store.similarity_search("foo", k=1)
    assert hits == expected
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_documents(batch_size: int) -> None:
    """Check that documents added after construction can be searched.

    The same ``batch_size`` is used for the initial ``from_texts`` load and
    for the later ``add_documents`` call.
    """
    store: Qdrant = Qdrant.from_texts(
        ["foo", "bar", "baz"],
        ConsistentFakeEmbeddings(),
        location=":memory:",
        batch_size=batch_size,
    )

    extra_docs = [Document(page_content=text) for text in ["foobar", "foobaz"]]
    store.add_documents(extra_docs, batch_size=batch_size)

    hits = store.similarity_search("foobar", k=1)

    # Two outcomes are tolerated: the newly added "foobar" document, or the
    # very first document "foo". The second case exists because a fake
    # embedding may hand the query the same vector as the first document it
    # embedded, which would make "foo" and "foobar" indistinguishable.
    # NOTE(review): that rationale was written for `StatefulFakeEmbeddings`,
    # but this test uses ConsistentFakeEmbeddings; confirm the "foo" fallback
    # is still required.
    foobar_only = [Document(page_content="foobar")]
    foo_only = [Document(page_content="foo")]
    assert hits == foobar_only or hits == foo_only
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
    """Check that ``add_texts`` returns one distinct id per added text.

    ``batch_size`` is forwarded to ``add_texts`` as well as ``from_texts``;
    otherwise the parametrization would only affect the single-text initial
    load and the call under test would always run with its default batch
    size.
    """
    docsearch: Qdrant = Qdrant.from_texts(
        ["foobar"],
        ConsistentFakeEmbeddings(),
        location=":memory:",
        batch_size=batch_size,
    )

    new_texts = ["foo", "bar", "baz"]
    ids = docsearch.add_texts(new_texts, batch_size=batch_size)

    # One id per text ...
    assert len(ids) == len(new_texts)
    # ... and no id is repeated across texts (or across batches).
    assert len(set(ids)) == len(new_texts)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
    ["content_payload_key", "metadata_payload_key"],
    [
        (Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
        ("test_content", "test_payload"),
        (Qdrant.CONTENT_KEY, "payload_test"),
        ("content_test", Qdrant.METADATA_KEY),
    ],
)
def test_qdrant_with_metadatas(
    batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
    """Check that per-text metadata comes back with the search result.

    Each text is stored with ``{"page": <its index>}``; the top hit for
    "foo" is expected to carry page 0, whichever payload keys are used.
    """
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": page} for page, _ in enumerate(corpus)]

    store = Qdrant.from_texts(
        corpus,
        ConsistentFakeEmbeddings(),
        metadatas=page_metadata,
        location=":memory:",
        content_payload_key=content_payload_key,
        metadata_payload_key=metadata_payload_key,
        batch_size=batch_size,
    )

    expected = [Document(page_content="foo", metadata={"page": 0})]
    hits = store.similarity_search("foo", k=1)
    assert hits == expected
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_similarity_search_filters(batch_size: int) -> None:
    """Check that a nested metadata filter decides which document is returned.

    The query text is "foo", but the filter only matches the metadata stored
    with "bar", so "bar" is the expected (and only) hit.
    """
    corpus = ["foo", "bar", "baz"]
    # Top-level and nested values are offset from the index so that each
    # filter key has to match at its own nesting level.
    nested_metadata = [
        {"page": idx, "metadata": {"page": idx + 1, "pages": [idx + 2, -1]}}
        for idx, _ in enumerate(corpus)
    ]

    store = Qdrant.from_texts(
        corpus,
        ConsistentFakeEmbeddings(),
        metadatas=nested_metadata,
        location=":memory:",
        batch_size=batch_size,
    )

    # The filter lists only `3` under "pages" while the stored list is
    # `[3, -1]`; the test expects that to count as a match.
    only_bar = {"page": 1, "metadata": {"page": 2, "pages": [3]}}
    hits = store.similarity_search("foo", k=1, filter=only_bar)

    expected = [
        Document(
            page_content="bar",
            metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}},
        )
    ]
    assert hits == expected
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
    ["content_payload_key", "metadata_payload_key"],
    [
        (Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
        ("test_content", "test_payload"),
        (Qdrant.CONTENT_KEY, "payload_test"),
        ("content_test", Qdrant.METADATA_KEY),
    ],
)
def test_qdrant_max_marginal_relevance_search(
    batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
    """Build a store and run a max-marginal-relevance (MMR) search on it.

    With ``k=2`` and ``fetch_k=3`` over three texts, the test expects "foo"
    followed by "bar", each with its page metadata.
    """
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": page} for page, _ in enumerate(corpus)]

    store = Qdrant.from_texts(
        corpus,
        ConsistentFakeEmbeddings(),
        metadatas=page_metadata,
        location=":memory:",
        content_payload_key=content_payload_key,
        metadata_payload_key=metadata_payload_key,
        batch_size=batch_size,
    )

    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
    ]
    hits = store.max_marginal_relevance_search("foo", k=2, fetch_k=3)
    assert hits == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["embeddings", "embedding_function"],
    [
        (ConsistentFakeEmbeddings(), None),
        (ConsistentFakeEmbeddings().embed_query, None),
        (None, ConsistentFakeEmbeddings().embed_query),
    ],
)
def test_qdrant_embedding_interface(
    embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
) -> None:
    """Construct ``Qdrant`` with exactly one embedding source supplied.

    Covers an embeddings object, a bare callable passed as ``embeddings``,
    and a callable passed as ``embedding_function``. There is no explicit
    assertion: the test passes when the constructor does not raise.
    """
    from qdrant_client import QdrantClient

    Qdrant(
        QdrantClient(":memory:"),
        "test",
        embeddings=embeddings,
        embedding_function=embedding_function,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["embeddings", "embedding_function"],
    [
        (ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query),
        (None, None),
    ],
)
def test_qdrant_embedding_interface_raises(
    embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
) -> None:
    """Expect ``ValueError`` when both or neither embedding source is given."""
    from qdrant_client import QdrantClient

    # The client is created outside the `pytest.raises` block so that only
    # the Qdrant constructor can satisfy the expected ValueError.
    in_memory_client = QdrantClient(":memory:")

    with pytest.raises(ValueError):
        Qdrant(
            in_memory_client,
            "test",
            embeddings=embeddings,
            embedding_function=embedding_function,
        )
|