Fix: Qdrant ids (#5515)

# Fix Qdrant ids creation

There has been a bug in how the ids were created in the Qdrant vector
store. They were previously calculated based on the texts. However,
there are some scenarios in which two documents may have the same piece
of text but different metadata, and that's a valid case. Deduplication
should be done outside of insertion.

It has been fixed and covered with the integration tests.
---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
searx_updates
Kacper Łukawski 12 months ago committed by GitHub
parent d1f65d8dc1
commit 71a7c16ee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,6 @@ from __future__ import annotations
import uuid
import warnings
from hashlib import md5
from itertools import islice
from operator import itemgetter
from typing import (
@ -14,6 +13,7 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
@ -109,57 +109,11 @@ class Qdrant(VectorStore):
self._embeddings_function = embeddings
self.embeddings = None
def _embed_query(self, query: str) -> List[float]:
"""Embed query text.
Used to provide backward compatibility with `embedding_function` argument.
Args:
query: Query text.
Returns:
List of floats representing the query embedding.
"""
if self.embeddings is not None:
embedding = self.embeddings.embed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
"""Embed search texts.
Used to provide backward compatibility with `embedding_function` argument.
Args:
texts: Iterable of texts to embed.
Returns:
List of floats representing the texts embedding.
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embeddings, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
@ -168,20 +122,26 @@ class Qdrant(VectorStore):
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
batch_size:
How many vectors upload per-request.
Default: 64
Returns:
List of ids from adding the texts into the vectorstore.
"""
from qdrant_client.http import models as rest
ids = []
added_ids = []
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]
batch_ids = list(islice(ids_iterator, batch_size))
self.client.upsert(
collection_name=self.collection_name,
@ -197,9 +157,9 @@ class Qdrant(VectorStore):
),
)
ids.extend(batch_ids)
added_ids.extend(batch_ids)
return ids
return added_ids
def similarity_search(
self,
@ -313,6 +273,7 @@ class Qdrant(VectorStore):
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
@ -339,6 +300,9 @@ class Qdrant(VectorStore):
metadatas:
An optional list of metadata. If provided it has to be of the same
length as a list of texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
location:
If `:memory:` - use in-memory Qdrant instance.
If `str` - use it as a `url` parameter.
@ -378,6 +342,9 @@ class Qdrant(VectorStore):
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
batch_size:
How many vectors upload per-request.
Default: 64
**kwargs:
Additional arguments passed directly into REST client initialization
@ -439,9 +406,11 @@ class Qdrant(VectorStore):
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = embedding.embed_documents(batch_texts)
@ -449,7 +418,7 @@ class Qdrant(VectorStore):
client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
ids=batch_ids,
vectors=batch_embeddings,
payloads=cls._build_payloads(
batch_texts,
@ -544,3 +513,50 @@ class Qdrant(VectorStore):
for condition in self._build_condition(key, value)
]
)
def _embed_query(self, query: str) -> List[float]:
"""Embed query text.
Used to provide backward compatibility with `embedding_function` argument.
Args:
query: Query text.
Returns:
List of floats representing the query embedding.
"""
if self.embeddings is not None:
embedding = self.embeddings.embed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
"""Embed search texts.
Used to provide backward compatibility with `embedding_function` argument.
Args:
texts: Iterable of texts to embed.
Returns:
List of floats representing the texts embedding.
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embeddings, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings

@ -1,4 +1,5 @@
"""Test Qdrant functionality."""
import tempfile
from typing import Callable, Optional
import pytest
@ -247,3 +248,91 @@ def test_qdrant_embedding_interface_raises(
embeddings=embeddings,
embedding_function=embedding_function,
)
def test_qdrant_stores_duplicated_texts() -> None:
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
client = QdrantClient(":memory:")
collection_name = "test"
client.recreate_collection(
collection_name,
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
)
vec_store = Qdrant(
client,
collection_name,
embeddings=ConsistentFakeEmbeddings(),
)
ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}])
assert 2 == len(set(ids))
assert 2 == client.count(collection_name).count
def test_qdrant_from_texts_stores_duplicated_texts() -> None:
from qdrant_client import QdrantClient
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["abc", "abc"],
ConsistentFakeEmbeddings(),
collection_name="test",
path=str(tmpdir),
)
del vec_store
client = QdrantClient(path=str(tmpdir))
assert 2 == client.count("test").count
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_from_texts_stores_ids(batch_size: int) -> None:
from qdrant_client import QdrantClient
with tempfile.TemporaryDirectory() as tmpdir:
ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df",
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
vec_store = Qdrant.from_texts(
["abc", "def"],
ConsistentFakeEmbeddings(),
ids=ids,
collection_name="test",
path=str(tmpdir),
batch_size=batch_size,
)
del vec_store
client = QdrantClient(path=str(tmpdir))
assert 2 == client.count("test").count
stored_ids = [point.id for point in client.scroll("test")[0]]
assert set(ids) == set(stored_ids)
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
from qdrant_client import QdrantClient
ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df",
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
client = QdrantClient(":memory:")
collection_name = "test"
client.recreate_collection(
collection_name,
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
)
vec_store = Qdrant(client, "test", ConsistentFakeEmbeddings())
returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size)
assert all(first == second for first, second in zip(ids, returned_ids))
assert 2 == client.count("test").count
stored_ids = [point.id for point in client.scroll("test")[0]]
assert set(ids) == set(stored_ids)

Loading…
Cancel
Save