diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 5af38b96..1ff1e061 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -4,6 +4,7 @@ from __future__ import annotations import uuid import warnings from hashlib import md5 +from itertools import islice from operator import itemgetter from typing import ( TYPE_CHECKING, @@ -158,6 +159,7 @@ class Qdrant(VectorStore): self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + batch_size: int = 64, **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. @@ -171,24 +173,30 @@ class Qdrant(VectorStore): """ from qdrant_client.http import models as rest - texts = list( - texts - ) # otherwise iterable might be exhausted after id calculation - ids = [md5(text.encode("utf-8")).hexdigest() for text in texts] - - self.client.upsert( - collection_name=self.collection_name, - points=rest.Batch.construct( - ids=ids, - vectors=self._embed_texts(texts), - payloads=self._build_payloads( - texts, - metadatas, - self.content_payload_key, - self.metadata_payload_key, + ids = [] + texts_iterator = iter(texts) + metadatas_iterator = iter(metadatas or []) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + + batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts] + + self.client.upsert( + collection_name=self.collection_name, + points=rest.Batch.construct( + ids=batch_ids, + vectors=self._embed_texts(batch_texts), + payloads=self._build_payloads( + batch_texts, + batch_metadatas, + self.content_payload_key, + self.metadata_payload_key, + ), ), - ), - ) + ) + + ids.extend(batch_ids) return ids @@ -309,6 +317,7 @@ class Qdrant(VectorStore): distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, + batch_size: int = 64, **kwargs: Any, ) -> Qdrant: """Construct Qdrant wrapper from a list of texts. @@ -361,7 +370,7 @@ class Qdrant(VectorStore): **kwargs: Additional arguments passed directly into REST client initialization - This is a user friendly interface that: + This is a user-friendly interface that: 1. Creates embeddings, one for each text 2. Initializes the Qdrant database as an in-memory docstore by default (and overridable to a remote docstore) @@ -417,19 +426,28 @@ class Qdrant(VectorStore): ), ) - # Now generate the embeddings for all the texts - embeddings = embedding.embed_documents(texts) - - client.upsert( - collection_name=collection_name, - points=rest.Batch.construct( - ids=[md5(text.encode("utf-8")).hexdigest() for text in texts], - vectors=embeddings, - payloads=cls._build_payloads( - texts, metadatas, content_payload_key, metadata_payload_key + texts_iterator = iter(texts) + metadatas_iterator = iter(metadatas or []) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + + # Generate the embeddings for all the texts in a batch + batch_embeddings = embedding.embed_documents(batch_texts) + + client.upsert( + collection_name=collection_name, + points=rest.Batch.construct( + ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts], + vectors=batch_embeddings, + payloads=cls._build_payloads( + batch_texts, + batch_metadatas, + content_payload_key, + metadata_payload_key, + ), ), - ), - ) + ) return cls( client=client, diff --git a/tests/integration_tests/vectorstores/fake_embeddings.py b/tests/integration_tests/vectorstores/fake_embeddings.py index 17a81e04..d6914e8a 100644 --- a/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/tests/integration_tests/vectorstores/fake_embeddings.py @@ -20,3 +20,28 @@ class FakeEmbeddings(Embeddings): Distance to each text will be that text's index, as it was passed to embed_documents.""" return [float(1.0)] * 9 + [float(0.0)] + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self) -> None: + self.known_texts: List[str] = [] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * 9 + [float(self.known_texts.index(text))] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + if text not in self.known_texts: + return [float(1.0)] * 9 + [float(0.0)] + return [float(1.0)] * 9 + [float(self.known_texts.index(text))] diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 8362951c..b7c8bca4 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -6,9 +6,12 @@ import pytest from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.integration_tests.vectorstores.fake_embeddings import ( + ConsistentFakeEmbeddings, +) +@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize( ["content_payload_key", "metadata_payload_key"], [ @@ -18,36 +21,59 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings ("foo", Qdrant.METADATA_KEY), ], ) -def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None: +def test_qdrant_similarity_search( + batch_size: int, content_payload_key: str, metadata_payload_key: str +) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = Qdrant.from_texts( texts, - FakeEmbeddings(), + ConsistentFakeEmbeddings(), location=":memory:", content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, + batch_size=batch_size, ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] -def test_qdrant_add_documents() -> None: +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_add_documents(batch_size: int) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] - docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:") + docsearch: Qdrant = Qdrant.from_texts( + texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size + ) new_texts = ["foobar", "foobaz"] - docsearch.add_documents([Document(page_content=content) for content in new_texts]) + docsearch.add_documents( + [Document(page_content=content) for content in new_texts], batch_size=batch_size + ) output = docsearch.similarity_search("foobar", k=1) - # FakeEmbeddings return the same query embedding as the first document embedding - # computed in `embedding.embed_documents`. Since embed_documents is called twice, - # "foo" embedding is the same as "foobar" embedding + # StatefulFakeEmbeddings return the same query embedding as the first document + # embedding computed in `embedding.embed_documents`. Thus, "foo" embedding is the + # same as "foobar" embedding assert output == [Document(page_content="foobar")] or output == [ Document(page_content="foo") ] +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None: + docsearch: Qdrant = Qdrant.from_texts( + ["foobar"], + ConsistentFakeEmbeddings(), + location=":memory:", + batch_size=batch_size, + ) + + ids = docsearch.add_texts(["foo", "bar", "baz"]) + assert 3 == len(ids) + assert 3 == len(set(ids)) + + +@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize( ["content_payload_key", "metadata_payload_key"], [ @@ -58,24 +84,26 @@ def test_qdrant_add_documents() -> None: ], ) def test_qdrant_with_metadatas( - content_payload_key: str, metadata_payload_key: str + batch_size: int, content_payload_key: str, metadata_payload_key: str ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = Qdrant.from_texts( texts, - FakeEmbeddings(), + ConsistentFakeEmbeddings(), metadatas=metadatas, location=":memory:", content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, + batch_size=batch_size, ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo", metadata={"page": 0})] -def test_qdrant_similarity_search_filters() -> None: +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_similarity_search_filters(batch_size: int) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [ @@ -84,9 +112,10 @@ def test_qdrant_similarity_search_filters() -> None: ] docsearch = Qdrant.from_texts( texts, - FakeEmbeddings(), + ConsistentFakeEmbeddings(), metadatas=metadatas, location=":memory:", + batch_size=batch_size, ) output = docsearch.similarity_search( @@ -100,6 +129,7 @@ def test_qdrant_similarity_search_filters() -> None: ] +@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize( ["content_payload_key", "metadata_payload_key"], [ @@ -110,18 +140,19 @@ def test_qdrant_similarity_search_filters() -> None: ], ) def test_qdrant_max_marginal_relevance_search( - content_payload_key: str, metadata_payload_key: str + batch_size: int, content_payload_key: str, metadata_payload_key: str ) -> None: """Test end to end construction and MRR search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = Qdrant.from_texts( texts, - FakeEmbeddings(), + ConsistentFakeEmbeddings(), metadatas=metadatas, location=":memory:", content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, + batch_size=batch_size, ) output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) assert output == [ @@ -133,9 +164,9 @@ def test_qdrant_max_marginal_relevance_search( @pytest.mark.parametrize( ["embeddings", "embedding_function"], [ - (FakeEmbeddings(), None), - (FakeEmbeddings().embed_query, None), - (None, FakeEmbeddings().embed_query), + (ConsistentFakeEmbeddings(), None), + (ConsistentFakeEmbeddings().embed_query, None), + (None, ConsistentFakeEmbeddings().embed_query), ], ) def test_qdrant_embedding_interface( @@ -157,7 +188,7 @@ def test_qdrant_embedding_interface( @pytest.mark.parametrize( ["embeddings", "embedding_function"], [ - (FakeEmbeddings(), FakeEmbeddings().embed_query), + (ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query), (None, None), ], )