From 71a7c16ee03ac17261a759db047e966b9792db9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Fri, 2 Jun 2023 17:57:34 +0200 Subject: [PATCH] Fix: Qdrant ids (#5515) # Fix Qdrant ids creation There has been a bug in how the ids were created in the Qdrant vector store. They were previously calculated based on the texts. However, there are some scenarios in which two documents may have the same piece of text but different metadata, and that's a valid case. Deduplication should be done outside of insertion. It has been fixed and covered with the integration tests. --------- Co-authored-by: Dev 2049 --- langchain/vectorstores/qdrant.py | 128 ++++++++++-------- .../vectorstores/test_qdrant.py | 89 ++++++++++++ 2 files changed, 161 insertions(+), 56 deletions(-) diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index c3fead14..9c5f2f7a 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -3,7 +3,6 @@ from __future__ import annotations import uuid import warnings -from hashlib import md5 from itertools import islice from operator import itemgetter from typing import ( @@ -14,6 +13,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Tuple, Type, Union, @@ -109,57 +109,11 @@ class Qdrant(VectorStore): self._embeddings_function = embeddings self.embeddings = None - def _embed_query(self, query: str) -> List[float]: - """Embed query text. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = self.embeddings.embed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. - """ - if self.embeddings is not None: - embeddings = self.embeddings.embed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + ids: Optional[Sequence[str]] = None, batch_size: int = 64, **kwargs: Any, ) -> List[str]: @@ -168,20 +122,26 @@ class Qdrant(VectorStore): Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: + Optional list of ids to associate with the texts. Ids have to be + uuid-like strings. + batch_size: + How many vectors upload per-request. + Default: 64 Returns: List of ids from adding the texts into the vectorstore. """ from qdrant_client.http import models as rest - ids = [] + added_ids = [] texts_iterator = iter(texts) metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata for each text in a batch + # Take the corresponding metadata and id for each text in a batch batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - - batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts] + batch_ids = list(islice(ids_iterator, batch_size)) self.client.upsert( collection_name=self.collection_name, @@ -197,9 +157,9 @@ class Qdrant(VectorStore): ), ) - ids.extend(batch_ids) + added_ids.extend(batch_ids) - return ids + return added_ids def similarity_search( self, @@ -313,6 +273,7 @@ class Qdrant(VectorStore): texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + ids: Optional[Sequence[str]] = None, location: Optional[str] = None, url: Optional[str] = None, port: Optional[int] = 6333, @@ -339,6 +300,9 @@ class Qdrant(VectorStore): metadatas: An optional list of metadata. If provided it has to be of the same length as a list of texts. + ids: + Optional list of ids to associate with the texts. Ids have to be + uuid-like strings. location: If `:memory:` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter. @@ -378,6 +342,9 @@ class Qdrant(VectorStore): metadata_payload_key: A payload key used to store the metadata of the document. Default: "metadata" + batch_size: + How many vectors upload per-request. + Default: 64 **kwargs: Additional arguments passed directly into REST client initialization @@ -439,9 +406,11 @@ class Qdrant(VectorStore): texts_iterator = iter(texts) metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata for each text in a batch + # Take the corresponding metadata and id for each text in a batch batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) # Generate the embeddings for all the texts in a batch batch_embeddings = embedding.embed_documents(batch_texts) @@ -449,7 +418,7 @@ class Qdrant(VectorStore): client.upsert( collection_name=collection_name, points=rest.Batch.construct( - ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts], + ids=batch_ids, vectors=batch_embeddings, payloads=cls._build_payloads( batch_texts, @@ -544,3 +513,50 @@ class Qdrant(VectorStore): for condition in self._build_condition(key, value) ] ) + + def _embed_query(self, query: str) -> List[float]: + """Embed query text. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + query: Query text. + + Returns: + List of floats representing the query embedding. + """ + if self.embeddings is not None: + embedding = self.embeddings.embed_query(query) + else: + if self._embeddings_function is not None: + embedding = self._embeddings_function(query) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + return embedding.tolist() if hasattr(embedding, "tolist") else embedding + + def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: + """Embed search texts. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + texts: Iterable of texts to embed. + + Returns: + List of floats representing the texts embedding. + """ + if self.embeddings is not None: + embeddings = self.embeddings.embed_documents(list(texts)) + if hasattr(embeddings, "tolist"): + embeddings = embeddings.tolist() + elif self._embeddings_function is not None: + embeddings = [] + for text in texts: + embedding = self._embeddings_function(text) + if hasattr(embeddings, "tolist"): + embedding = embedding.tolist() + embeddings.append(embedding) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + + return embeddings diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 1f4db0aa..3dde3753 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -1,4 +1,5 @@ """Test Qdrant functionality.""" +import tempfile from typing import Callable, Optional import pytest @@ -247,3 +248,91 @@ def test_qdrant_embedding_interface_raises( embeddings=embeddings, embedding_function=embedding_function, ) + + +def test_qdrant_stores_duplicated_texts() -> None: + from qdrant_client import QdrantClient + from qdrant_client.http import models as rest + + client = QdrantClient(":memory:") + collection_name = "test" + client.recreate_collection( + collection_name, + vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), + ) + + vec_store = Qdrant( + client, + collection_name, + embeddings=ConsistentFakeEmbeddings(), + ) + ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) + + assert 2 == len(set(ids)) + assert 2 == client.count(collection_name).count + + +def test_qdrant_from_texts_stores_duplicated_texts() -> None: + from qdrant_client import QdrantClient + + with tempfile.TemporaryDirectory() as tmpdir: + vec_store = Qdrant.from_texts( + ["abc", "abc"], + ConsistentFakeEmbeddings(), + collection_name="test", + path=str(tmpdir), + ) + del vec_store + + client = QdrantClient(path=str(tmpdir)) + assert 2 == client.count("test").count + + +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_from_texts_stores_ids(batch_size: int) -> None: + from qdrant_client import QdrantClient + + with tempfile.TemporaryDirectory() as tmpdir: + ids = [ + "fa38d572-4c31-4579-aedc-1960d79df6df", + "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", + ] + vec_store = Qdrant.from_texts( + ["abc", "def"], + ConsistentFakeEmbeddings(), + ids=ids, + collection_name="test", + path=str(tmpdir), + batch_size=batch_size, + ) + del vec_store + + client = QdrantClient(path=str(tmpdir)) + assert 2 == client.count("test").count + stored_ids = [point.id for point in client.scroll("test")[0]] + assert set(ids) == set(stored_ids) + + +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_add_texts_stores_ids(batch_size: int) -> None: + from qdrant_client import QdrantClient + + ids = [ + "fa38d572-4c31-4579-aedc-1960d79df6df", + "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", + ] + + client = QdrantClient(":memory:") + collection_name = "test" + client.recreate_collection( + collection_name, + vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), + ) + + vec_store = Qdrant(client, "test", ConsistentFakeEmbeddings()) + returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size) + + assert all(first == second for first, second in zip(ids, returned_ids)) + assert 2 == client.count("test").count + stored_ids = [point.id for point in client.scroll("test")[0]] + assert set(ids) == set(stored_ids)