Refactor Qdrant integration (#8634)

This small PR introduces a new parameter into Qdrant (`on_disk`), fixes
some tests, and makes the error message clearer.

Tagging: @baskaryan, @rlancemartin, @eyurtsev
This commit is contained in:
Kacper Łukawski 2023-08-02 19:30:18 +02:00 committed by GitHub
parent c5fb3b6069
commit 16551536e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 71 additions and 32 deletions

View File

@ -142,7 +142,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
insert_result = self._collection.insert_many(to_insert)
insert_result = self._collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids
def _similarity_search_with_score(
@ -170,7 +170,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
]
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
for res in cursor:
text = res.pop(self._text_key)

View File

@ -981,6 +981,7 @@ class Qdrant(VectorStore):
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
@ -1090,8 +1091,6 @@ class Qdrant(VectorStore):
qdrant = cls._construct_instance(
texts,
embedding,
metadatas,
ids,
location,
url,
port,
@ -1117,6 +1116,7 @@ class Qdrant(VectorStore):
wal_config,
quantization_config,
init_from,
on_disk,
force_recreate,
**kwargs,
)
@ -1157,6 +1157,7 @@ class Qdrant(VectorStore):
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
@ -1266,8 +1267,6 @@ class Qdrant(VectorStore):
qdrant = cls._construct_instance(
texts,
embedding,
metadatas,
ids,
location,
url,
port,
@ -1293,6 +1292,7 @@ class Qdrant(VectorStore):
wal_config,
quantization_config,
init_from,
on_disk,
force_recreate,
**kwargs,
)
@ -1304,8 +1304,6 @@ class Qdrant(VectorStore):
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
@ -1331,6 +1329,7 @@ class Qdrant(VectorStore):
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
@ -1421,16 +1420,17 @@ class Qdrant(VectorStore):
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
f"`{distance_func}` if you want to reuse it. If you want to "
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
f"{current_distance_func} similarity, but requested "
f"{distance_func}. Please set `distance_func` parameter to "
f"`{current_distance_func}` if you want to reuse it. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
on_disk=on_disk,
)
# If vector name was provided, we're going to use the named vectors feature

View File

@ -1,3 +1,4 @@
import uuid
from typing import Optional
import pytest
@ -42,7 +43,7 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
from qdrant_client.http import models as rest
client = QdrantClient(location=qdrant_location)
collection_name = "test"
collection_name = uuid.uuid4().hex
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None:
vectors_config = {vector_name: vectors_config} # type: ignore[assignment]
@ -75,7 +76,7 @@ async def test_qdrant_aadd_texts_stores_ids(
]
client = QdrantClient(location=qdrant_location)
collection_name = "test"
collection_name = uuid.uuid4().hex
client.recreate_collection(
collection_name,
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
@ -101,7 +102,7 @@ async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
client = QdrantClient(location=qdrant_location)
client.recreate_collection(

View File

@ -224,6 +224,10 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
client = QdrantClient()
assert 2 == client.count(collection_name).count
vector_params = client.get_collection(collection_name).config.params.vectors
if vector_name is not None:
vector_params = vector_params[vector_name] # type: ignore[index]
assert 5 == vector_params.size # type: ignore[union-attr]
@pytest.mark.asyncio

View File

@ -1,3 +1,4 @@
import uuid
from typing import Optional
import pytest
@ -58,7 +59,7 @@ def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) ->
from qdrant_client.http import models as rest
client = QdrantClient(":memory:")
collection_name = "test"
collection_name = uuid.uuid4().hex
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None:
vectors_config = {vector_name: vectors_config} # type: ignore[assignment]
@ -87,7 +88,7 @@ def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
]
client = QdrantClient(":memory:")
collection_name = "test"
collection_name = uuid.uuid4().hex
client.recreate_collection(
collection_name,
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
@ -107,7 +108,7 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -
"""Test end to end Qdrant.add_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
client = QdrantClient(":memory:")
client.recreate_collection(

View File

@ -1,3 +1,4 @@
import uuid
from typing import Callable, Optional
import pytest
@ -24,7 +25,7 @@ def test_qdrant_embedding_interface(
from qdrant_client import QdrantClient
client = QdrantClient(":memory:")
collection_name = "test"
collection_name = uuid.uuid4().hex
Qdrant(
client,
@ -48,7 +49,7 @@ def test_qdrant_embedding_interface_raises_value_error(
from qdrant_client import QdrantClient
client = QdrantClient(":memory:")
collection_name = "test"
collection_name = uuid.uuid4().hex
with pytest.raises(ValueError):
Qdrant(

View File

@ -1,4 +1,5 @@
import tempfile
import uuid
from typing import Optional
import pytest
@ -9,13 +10,14 @@ from langchain.vectorstores.qdrant import QdrantException
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
def test_qdrant_from_texts_stores_duplicated_texts() -> None:
"""Test end to end Qdrant.from_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
@ -38,7 +40,7 @@ def test_qdrant_from_texts_stores_ids(
"""Test end to end Qdrant.from_texts stores provided ids."""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df",
@ -66,7 +68,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str)
"""Test end to end Qdrant.from_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
@ -90,7 +92,7 @@ def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) ->
"""Test if Qdrant.from_texts reuses the same collection"""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings()
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
@ -120,7 +122,7 @@ def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str],
) -> None:
"""Test if Qdrant.from_texts raises an exception if dimensionality does not match"""
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
@ -154,7 +156,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
second_vector_name: Optional[str],
) -> None:
"""Test if Qdrant.from_texts raises an exception if vector name does not match"""
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
@ -177,26 +179,32 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
"""Test if Qdrant.from_texts raises an exception if distance does not match"""
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
ConsistentFakeEmbeddings(dimensionality=10),
ConsistentFakeEmbeddings(),
collection_name=collection_name,
path=str(tmpdir),
distance_func="Cosine",
)
del vec_store
with pytest.raises(QdrantException):
with pytest.raises(QdrantException) as excinfo:
Qdrant.from_texts(
["foo", "bar"],
ConsistentFakeEmbeddings(dimensionality=5),
ConsistentFakeEmbeddings(),
collection_name=collection_name,
path=str(tmpdir),
distance_func="Euclid",
)
expected_message = (
"configured for COSINE similarity, but requested EUCLID. Please set "
"`distance_func` parameter to `COSINE`"
)
assert expected_message in str(excinfo.value)
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_recreates_collection_on_force_recreate(
@ -205,7 +213,7 @@ def test_qdrant_from_texts_recreates_collection_on_force_recreate(
"""Test if Qdrant.from_texts recreates the collection even if config mismatches"""
from qdrant_client import QdrantClient
collection_name = "test"
collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
@ -250,3 +258,27 @@ def test_qdrant_from_texts_stores_metadatas(
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
def test_from_texts_passed_optimizers_config_and_on_disk_payload() -> None:
    """Check that collection options given to ``Qdrant.from_texts`` are applied.

    A collection with a random name is created with ``optimizers_config``,
    ``on_disk_payload`` and ``on_disk`` set, and the collection info read back
    through the store's client is asserted to reflect each of them.
    """
    from qdrant_client import models

    # A random name keeps this test from colliding with other collections.
    name = uuid.uuid4().hex
    documents = ["foo", "bar", "baz"]

    store = Qdrant.from_texts(
        documents,
        ConsistentFakeEmbeddings(),
        metadatas=[{"page": idx} for idx, _ in enumerate(documents)],
        optimizers_config=models.OptimizersConfigDiff(memmap_threshold=1000),
        on_disk_payload=True,
        on_disk=True,
        collection_name=name,
    )

    # Read the configuration back from the collection that was just created.
    config = store.client.get_collection(name).config
    assert config.params.vectors.on_disk is True  # type: ignore
    assert config.optimizer_config.memmap_threshold == 1000
    assert config.params.on_disk_payload is True