diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index c3fead14..9c5f2f7a 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -3,7 +3,6 @@ from __future__ import annotations import uuid import warnings -from hashlib import md5 from itertools import islice from operator import itemgetter from typing import ( @@ -14,6 +13,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Tuple, Type, Union, @@ -109,57 +109,11 @@ class Qdrant(VectorStore): self._embeddings_function = embeddings self.embeddings = None - def _embed_query(self, query: str) -> List[float]: - """Embed query text. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = self.embeddings.embed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. - """ - if self.embeddings is not None: - embeddings = self.embeddings.embed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + ids: Optional[Sequence[str]] = None, batch_size: int = 64, **kwargs: Any, ) -> List[str]: @@ -168,20 +122,26 @@ class Qdrant(VectorStore): Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: + Optional list of ids to associate with the texts. Ids have to be + uuid-like strings. + batch_size: + How many vectors upload per-request. + Default: 64 Returns: List of ids from adding the texts into the vectorstore. """ from qdrant_client.http import models as rest - ids = [] + added_ids = [] texts_iterator = iter(texts) metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata for each text in a batch + # Take the corresponding metadata and id for each text in a batch batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - - batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts] + batch_ids = list(islice(ids_iterator, batch_size)) self.client.upsert( collection_name=self.collection_name, @@ -197,9 +157,9 @@ class Qdrant(VectorStore): ), ) - ids.extend(batch_ids) + added_ids.extend(batch_ids) - return ids + return added_ids def similarity_search( self, @@ -313,6 +273,7 @@ class Qdrant(VectorStore): texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + ids: Optional[Sequence[str]] = None, location: Optional[str] = None, url: Optional[str] = None, port: Optional[int] = 6333, @@ -339,6 +300,9 @@ class Qdrant(VectorStore): metadatas: An optional list of metadata. If provided it has to be of the same length as a list of texts. + ids: + Optional list of ids to associate with the texts. Ids have to be + uuid-like strings. location: If `:memory:` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter. @@ -378,6 +342,9 @@ class Qdrant(VectorStore): metadata_payload_key: A payload key used to store the metadata of the document. Default: "metadata" + batch_size: + How many vectors upload per-request. + Default: 64 **kwargs: Additional arguments passed directly into REST client initialization @@ -439,9 +406,11 @@ class Qdrant(VectorStore): texts_iterator = iter(texts) metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata for each text in a batch + # Take the corresponding metadata and id for each text in a batch batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) # Generate the embeddings for all the texts in a batch batch_embeddings = embedding.embed_documents(batch_texts) @@ -449,7 +418,7 @@ class Qdrant(VectorStore): client.upsert( collection_name=collection_name, points=rest.Batch.construct( - ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts], + ids=batch_ids, vectors=batch_embeddings, payloads=cls._build_payloads( batch_texts, @@ -544,3 +513,50 @@ class Qdrant(VectorStore): for condition in self._build_condition(key, value) ] ) + + def _embed_query(self, query: str) -> List[float]: + """Embed query text. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + query: Query text. + + Returns: + List of floats representing the query embedding. + """ + if self.embeddings is not None: + embedding = self.embeddings.embed_query(query) + else: + if self._embeddings_function is not None: + embedding = self._embeddings_function(query) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + return embedding.tolist() if hasattr(embedding, "tolist") else embedding + + def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: + """Embed search texts. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + texts: Iterable of texts to embed. + + Returns: + List of floats representing the texts embedding. + """ + if self.embeddings is not None: + embeddings = self.embeddings.embed_documents(list(texts)) + if hasattr(embeddings, "tolist"): + embeddings = embeddings.tolist() + elif self._embeddings_function is not None: + embeddings = [] + for text in texts: + embedding = self._embeddings_function(text) + if hasattr(embeddings, "tolist"): + embedding = embedding.tolist() + embeddings.append(embedding) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + + return embeddings diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 1f4db0aa..3dde3753 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -1,4 +1,5 @@ """Test Qdrant functionality.""" +import tempfile from typing import Callable, Optional import pytest @@ -247,3 +248,91 @@ def test_qdrant_embedding_interface_raises( embeddings=embeddings, embedding_function=embedding_function, ) + + +def test_qdrant_stores_duplicated_texts() -> None: + from qdrant_client import QdrantClient + from qdrant_client.http import models as rest + + client = QdrantClient(":memory:") + collection_name = "test" + client.recreate_collection( + collection_name, + vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), + ) + + vec_store = Qdrant( + client, + collection_name, + embeddings=ConsistentFakeEmbeddings(), + ) + ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) + + assert 2 == len(set(ids)) + assert 2 == client.count(collection_name).count + + +def test_qdrant_from_texts_stores_duplicated_texts() -> None: + from qdrant_client import QdrantClient + + with tempfile.TemporaryDirectory() as tmpdir: + vec_store = Qdrant.from_texts( + ["abc", "abc"], + ConsistentFakeEmbeddings(), + collection_name="test", + path=str(tmpdir), + ) + del vec_store + + client = QdrantClient(path=str(tmpdir)) + assert 2 == client.count("test").count + + +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_from_texts_stores_ids(batch_size: int) -> None: + from qdrant_client import QdrantClient + + with tempfile.TemporaryDirectory() as tmpdir: + ids = [ + "fa38d572-4c31-4579-aedc-1960d79df6df", + "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", + ] + vec_store = Qdrant.from_texts( + ["abc", "def"], + ConsistentFakeEmbeddings(), + ids=ids, + collection_name="test", + path=str(tmpdir), + batch_size=batch_size, + ) + del vec_store + + client = QdrantClient(path=str(tmpdir)) + assert 2 == client.count("test").count + stored_ids = [point.id for point in client.scroll("test")[0]] + assert set(ids) == set(stored_ids) + + +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_qdrant_add_texts_stores_ids(batch_size: int) -> None: + from qdrant_client import QdrantClient + + ids = [ + "fa38d572-4c31-4579-aedc-1960d79df6df", + "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", + ] + + client = QdrantClient(":memory:") + collection_name = "test" + client.recreate_collection( + collection_name, + vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), + ) + + vec_store = Qdrant(client, "test", ConsistentFakeEmbeddings()) + returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size) + + assert all(first == second for first, second in zip(ids, returned_ids)) + assert 2 == client.count("test").count + stored_ids = [point.id for point in client.scroll("test")[0]] + assert set(ids) == set(stored_ids)