From 16551536e3abad3a45329d5a6b658780c3c3e5c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Wed, 2 Aug 2023 19:30:18 +0200 Subject: [PATCH] Refactor Qdrant integration (#8634) This small PR introduces new parameters into Qdrant (`on_disk`), fixes some tests and changes the error message to be more clear. Tagging: @baskaryan, @rlancemartin, @eyurtsev --- .../langchain/vectorstores/mongodb_atlas.py | 4 +- .../langchain/vectorstores/qdrant.py | 22 ++++---- .../qdrant/async_api/test_add_texts.py | 7 +-- .../qdrant/async_api/test_from_texts.py | 4 ++ .../vectorstores/qdrant/test_add_texts.py | 7 +-- .../qdrant/test_embedding_interface.py | 5 +- .../vectorstores/qdrant/test_from_texts.py | 54 +++++++++++++++---- 7 files changed, 71 insertions(+), 32 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/mongodb_atlas.py b/libs/langchain/langchain/vectorstores/mongodb_atlas.py index 8306de4898..b7cd5136dd 100644 --- a/libs/langchain/langchain/vectorstores/mongodb_atlas.py +++ b/libs/langchain/langchain/vectorstores/mongodb_atlas.py @@ -142,7 +142,7 @@ class MongoDBAtlasVectorSearch(VectorStore): for t, m, embedding in zip(texts, metadatas, embeddings) ] # insert the documents in MongoDB Atlas - insert_result = self._collection.insert_many(to_insert) + insert_result = self._collection.insert_many(to_insert) # type: ignore return insert_result.inserted_ids def _similarity_search_with_score( @@ -170,7 +170,7 @@ class MongoDBAtlasVectorSearch(VectorStore): ] if post_filter_pipeline is not None: pipeline.extend(post_filter_pipeline) - cursor = self._collection.aggregate(pipeline) + cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type] docs = [] for res in cursor: text = res.pop(self._text_key) diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py index 611c4dfc58..44d8f9a6cd 100644 --- a/libs/langchain/langchain/vectorstores/qdrant.py +++ b/libs/langchain/langchain/vectorstores/qdrant.py @@ -981,6 +981,7 @@ class Qdrant(VectorStore): wal_config: Optional[common_types.WalConfigDiff] = None, quantization_config: Optional[common_types.QuantizationConfig] = None, init_from: Optional[common_types.InitFrom] = None, + on_disk: Optional[bool] = None, force_recreate: bool = False, **kwargs: Any, ) -> Qdrant: @@ -1090,8 +1091,6 @@ class Qdrant(VectorStore): qdrant = cls._construct_instance( texts, embedding, - metadatas, - ids, location, url, port, @@ -1117,6 +1116,7 @@ class Qdrant(VectorStore): wal_config, quantization_config, init_from, + on_disk, force_recreate, **kwargs, ) @@ -1157,6 +1157,7 @@ class Qdrant(VectorStore): wal_config: Optional[common_types.WalConfigDiff] = None, quantization_config: Optional[common_types.QuantizationConfig] = None, init_from: Optional[common_types.InitFrom] = None, + on_disk: Optional[bool] = None, force_recreate: bool = False, **kwargs: Any, ) -> Qdrant: @@ -1266,8 +1267,6 @@ class Qdrant(VectorStore): qdrant = cls._construct_instance( texts, embedding, - metadatas, - ids, location, url, port, @@ -1293,6 +1292,7 @@ class Qdrant(VectorStore): wal_config, quantization_config, init_from, + on_disk, force_recreate, **kwargs, ) @@ -1304,8 +1304,6 @@ class Qdrant(VectorStore): cls: Type[Qdrant], texts: List[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, location: Optional[str] = None, url: Optional[str] = None, port: Optional[int] = 6333, @@ -1331,6 +1329,7 @@ class Qdrant(VectorStore): wal_config: Optional[common_types.WalConfigDiff] = None, quantization_config: Optional[common_types.QuantizationConfig] = None, init_from: Optional[common_types.InitFrom] = None, + on_disk: Optional[bool] = None, force_recreate: bool = False, **kwargs: Any, ) -> Qdrant: @@ -1421,16 +1420,17 @@ class Qdrant(VectorStore): if current_distance_func != distance_func: raise QdrantException( f"Existing Qdrant collection is configured for " - f"{current_vector_config.distance} " # type: ignore[union-attr] - f"similarity. Please set `distance_func` parameter to " - f"`{distance_func}` if you want to reuse it. If you want to " - f"recreate the collection, set `force_recreate` parameter to " - f"`True`." + f"{current_distance_func} similarity, but requested " + f"{distance_func}. Please set `distance_func` parameter to " + f"`{current_distance_func}` if you want to reuse it. " + f"If you want to recreate the collection, set `force_recreate` " + f"parameter to `True`." ) except (UnexpectedResponse, RpcError, ValueError): vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[distance_func], + on_disk=on_disk, ) # If vector name was provided, we're going to use the named vectors feature diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py index 1de598d475..a6f3e72b78 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py @@ -1,3 +1,4 @@ +import uuid from typing import Optional import pytest @@ -42,7 +43,7 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts( from qdrant_client.http import models as rest client = QdrantClient(location=qdrant_location) - collection_name = "test" + collection_name = uuid.uuid4().hex vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) if vector_name is not None: vectors_config = {vector_name: vectors_config} # type: ignore[assignment] @@ -75,7 +76,7 @@ async def test_qdrant_aadd_texts_stores_ids( ] client = QdrantClient(location=qdrant_location) - collection_name = "test" + collection_name = uuid.uuid4().hex client.recreate_collection( collection_name, vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), @@ -101,7 +102,7 @@ async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors( """Test end to end Qdrant.aadd_texts stores named vectors if name is provided.""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex client = QdrantClient(location=qdrant_location) client.recreate_collection( diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py index 0240c691ce..a2c4b5eb77 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py @@ -224,6 +224,10 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate( client = QdrantClient() assert 2 == client.count(collection_name).count + vector_params = client.get_collection(collection_name).config.params.vectors + if vector_name is not None: + vector_params = vector_params[vector_name] # type: ignore[index] + assert 5 == vector_params.size # type: ignore[union-attr] @pytest.mark.asyncio diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py index 4afd823161..0290911e40 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py @@ -1,3 +1,4 @@ +import uuid from typing import Optional import pytest @@ -58,7 +59,7 @@ def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) -> from qdrant_client.http import models as rest client = QdrantClient(":memory:") - collection_name = "test" + collection_name = uuid.uuid4().hex vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) if vector_name is not None: vectors_config = {vector_name: vectors_config} # type: ignore[assignment] @@ -87,7 +88,7 @@ def test_qdrant_add_texts_stores_ids(batch_size: int) -> None: ] client = QdrantClient(":memory:") - collection_name = "test" + collection_name = uuid.uuid4().hex client.recreate_collection( collection_name, vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), @@ -107,7 +108,7 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) - """Test end to end Qdrant.add_texts stores named vectors if name is provided.""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex client = QdrantClient(":memory:") client.recreate_collection( diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py index 5b3d64bae8..9efc4d0c13 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py @@ -1,3 +1,4 @@ +import uuid from typing import Callable, Optional import pytest @@ -24,7 +25,7 @@ def test_qdrant_embedding_interface( from qdrant_client import QdrantClient client = QdrantClient(":memory:") - collection_name = "test" + collection_name = uuid.uuid4().hex Qdrant( client, @@ -48,7 +49,7 @@ def test_qdrant_embedding_interface_raises_value_error( from qdrant_client import QdrantClient client = QdrantClient(":memory:") - collection_name = "test" + collection_name = uuid.uuid4().hex with pytest.raises(ValueError): Qdrant( diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py index 03aeed59e3..142b5a10e8 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py @@ -1,4 +1,5 @@ import tempfile +import uuid from typing import Optional import pytest @@ -9,13 +10,14 @@ from langchain.vectorstores.qdrant import QdrantException from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) +from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running def test_qdrant_from_texts_stores_duplicated_texts() -> None: """Test end to end Qdrant.from_texts stores duplicated texts separately.""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( @@ -38,7 +40,7 @@ def test_qdrant_from_texts_stores_ids( """Test end to end Qdrant.from_texts stores provided ids.""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: ids = [ "fa38d572-4c31-4579-aedc-1960d79df6df", @@ -66,7 +68,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str) """Test end to end Qdrant.from_texts stores named vectors if name is provided.""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], @@ -90,7 +92,7 @@ def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> """Test if Qdrant.from_texts reuses the same collection""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex embeddings = ConsistentFakeEmbeddings() with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( @@ -120,7 +122,7 @@ def test_qdrant_from_texts_raises_error_on_different_dimensionality( vector_name: Optional[str], ) -> None: """Test if Qdrant.from_texts raises an exception if dimensionality does not match""" - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], @@ -154,7 +156,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name( second_vector_name: Optional[str], ) -> None: """Test if Qdrant.from_texts raises an exception if vector name does not match""" - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], @@ -177,26 +179,32 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name( def test_qdrant_from_texts_raises_error_on_different_distance() -> None: """Test if Qdrant.from_texts raises an exception if distance does not match""" - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), + ConsistentFakeEmbeddings(), collection_name=collection_name, path=str(tmpdir), distance_func="Cosine", ) del vec_store - with pytest.raises(QdrantException): + with pytest.raises(QdrantException) as excinfo: Qdrant.from_texts( ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), + ConsistentFakeEmbeddings(), collection_name=collection_name, path=str(tmpdir), distance_func="Euclid", ) + expected_message = ( + "configured for COSINE similarity, but requested EUCLID. Please set " + "`distance_func` parameter to `COSINE`" + ) + assert expected_message in str(excinfo.value) + @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) def test_qdrant_from_texts_recreates_collection_on_force_recreate( @@ -205,7 +213,7 @@ def test_qdrant_from_texts_recreates_collection_on_force_recreate( """Test if Qdrant.from_texts recreates the collection even if config mismatches""" from qdrant_client import QdrantClient - collection_name = "test" + collection_name = uuid.uuid4().hex with tempfile.TemporaryDirectory() as tmpdir: vec_store = Qdrant.from_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], @@ -250,3 +258,27 @@ def test_qdrant_from_texts_stores_metadatas( ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo", metadata={"page": 0})] + + +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") +def test_from_texts_passed_optimizers_config_and_on_disk_payload() -> None: + from qdrant_client import models + + collection_name = uuid.uuid4().hex + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + optimizers_config = models.OptimizersConfigDiff(memmap_threshold=1000) + vec_store = Qdrant.from_texts( + texts, + ConsistentFakeEmbeddings(), + metadatas=metadatas, + optimizers_config=optimizers_config, + on_disk_payload=True, + on_disk=True, + collection_name=collection_name, + ) + + collection_info = vec_store.client.get_collection(collection_name) + assert collection_info.config.params.vectors.on_disk is True # type: ignore + assert collection_info.config.optimizer_config.memmap_threshold == 1000 + assert collection_info.config.params.on_disk_payload is True