mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Refactor Qdrant integration (#8634)
This small PR adds a new `on_disk` parameter to the Qdrant integration, fixes some tests, and makes the distance-function mismatch error message clearer. Tagging: @baskaryan, @rlancemartin, @eyurtsev
This commit is contained in:
parent
c5fb3b6069
commit
16551536e3
@ -142,7 +142,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
for t, m, embedding in zip(texts, metadatas, embeddings)
|
||||
]
|
||||
# insert the documents in MongoDB Atlas
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
insert_result = self._collection.insert_many(to_insert) # type: ignore
|
||||
return insert_result.inserted_ids
|
||||
|
||||
def _similarity_search_with_score(
|
||||
@ -170,7 +170,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
]
|
||||
if post_filter_pipeline is not None:
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
cursor = self._collection.aggregate(pipeline)
|
||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
docs = []
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
|
@ -981,6 +981,7 @@ class Qdrant(VectorStore):
|
||||
wal_config: Optional[common_types.WalConfigDiff] = None,
|
||||
quantization_config: Optional[common_types.QuantizationConfig] = None,
|
||||
init_from: Optional[common_types.InitFrom] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
force_recreate: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
@ -1090,8 +1091,6 @@ class Qdrant(VectorStore):
|
||||
qdrant = cls._construct_instance(
|
||||
texts,
|
||||
embedding,
|
||||
metadatas,
|
||||
ids,
|
||||
location,
|
||||
url,
|
||||
port,
|
||||
@ -1117,6 +1116,7 @@ class Qdrant(VectorStore):
|
||||
wal_config,
|
||||
quantization_config,
|
||||
init_from,
|
||||
on_disk,
|
||||
force_recreate,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1157,6 +1157,7 @@ class Qdrant(VectorStore):
|
||||
wal_config: Optional[common_types.WalConfigDiff] = None,
|
||||
quantization_config: Optional[common_types.QuantizationConfig] = None,
|
||||
init_from: Optional[common_types.InitFrom] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
force_recreate: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
@ -1266,8 +1267,6 @@ class Qdrant(VectorStore):
|
||||
qdrant = cls._construct_instance(
|
||||
texts,
|
||||
embedding,
|
||||
metadatas,
|
||||
ids,
|
||||
location,
|
||||
url,
|
||||
port,
|
||||
@ -1293,6 +1292,7 @@ class Qdrant(VectorStore):
|
||||
wal_config,
|
||||
quantization_config,
|
||||
init_from,
|
||||
on_disk,
|
||||
force_recreate,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1304,8 +1304,6 @@ class Qdrant(VectorStore):
|
||||
cls: Type[Qdrant],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
location: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
port: Optional[int] = 6333,
|
||||
@ -1331,6 +1329,7 @@ class Qdrant(VectorStore):
|
||||
wal_config: Optional[common_types.WalConfigDiff] = None,
|
||||
quantization_config: Optional[common_types.QuantizationConfig] = None,
|
||||
init_from: Optional[common_types.InitFrom] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
force_recreate: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
@ -1421,16 +1420,17 @@ class Qdrant(VectorStore):
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
f"{current_distance_func} similarity, but requested "
|
||||
f"{distance_func}. Please set `distance_func` parameter to "
|
||||
f"`{current_distance_func}` if you want to reuse it. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
except (UnexpectedResponse, RpcError, ValueError):
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
on_disk=on_disk,
|
||||
)
|
||||
|
||||
# If vector name was provided, we're going to use the named vectors feature
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -42,7 +43,7 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
|
||||
if vector_name is not None:
|
||||
vectors_config = {vector_name: vectors_config} # type: ignore[assignment]
|
||||
@ -75,7 +76,7 @@ async def test_qdrant_aadd_texts_stores_ids(
|
||||
]
|
||||
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
|
||||
@ -101,7 +102,7 @@ async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
|
||||
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
client.recreate_collection(
|
||||
|
@ -224,6 +224,10 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||
|
||||
client = QdrantClient()
|
||||
assert 2 == client.count(collection_name).count
|
||||
vector_params = client.get_collection(collection_name).config.params.vectors
|
||||
if vector_name is not None:
|
||||
vector_params = vector_params[vector_name] # type: ignore[index]
|
||||
assert 5 == vector_params.size # type: ignore[union-attr]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -58,7 +59,7 @@ def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) ->
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
|
||||
if vector_name is not None:
|
||||
vectors_config = {vector_name: vectors_config} # type: ignore[assignment]
|
||||
@ -87,7 +88,7 @@ def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
|
||||
]
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
|
||||
@ -107,7 +108,7 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -
|
||||
"""Test end to end Qdrant.add_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
client.recreate_collection(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
@ -24,7 +25,7 @@ def test_qdrant_embedding_interface(
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
Qdrant(
|
||||
client,
|
||||
@ -48,7 +49,7 @@ def test_qdrant_embedding_interface_raises_value_error(
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Qdrant(
|
||||
|
@ -1,4 +1,5 @@
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -9,13 +10,14 @@ from langchain.vectorstores.qdrant import QdrantException
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
|
||||
|
||||
|
||||
def test_qdrant_from_texts_stores_duplicated_texts() -> None:
|
||||
"""Test end to end Qdrant.from_texts stores duplicated texts separately."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
@ -38,7 +40,7 @@ def test_qdrant_from_texts_stores_ids(
|
||||
"""Test end to end Qdrant.from_texts stores provided ids."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ids = [
|
||||
"fa38d572-4c31-4579-aedc-1960d79df6df",
|
||||
@ -66,7 +68,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str)
|
||||
"""Test end to end Qdrant.from_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
@ -90,7 +92,7 @@ def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) ->
|
||||
"""Test if Qdrant.from_texts reuses the same collection"""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
embeddings = ConsistentFakeEmbeddings()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
@ -120,7 +122,7 @@ def test_qdrant_from_texts_raises_error_on_different_dimensionality(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test if Qdrant.from_texts raises an exception if dimensionality does not match"""
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
@ -154,7 +156,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
|
||||
second_vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test if Qdrant.from_texts raises an exception if vector name does not match"""
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
@ -177,26 +179,32 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
|
||||
|
||||
def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
|
||||
"""Test if Qdrant.from_texts raises an exception if distance does not match"""
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
ConsistentFakeEmbeddings(dimensionality=10),
|
||||
ConsistentFakeEmbeddings(),
|
||||
collection_name=collection_name,
|
||||
path=str(tmpdir),
|
||||
distance_func="Cosine",
|
||||
)
|
||||
del vec_store
|
||||
|
||||
with pytest.raises(QdrantException):
|
||||
with pytest.raises(QdrantException) as excinfo:
|
||||
Qdrant.from_texts(
|
||||
["foo", "bar"],
|
||||
ConsistentFakeEmbeddings(dimensionality=5),
|
||||
ConsistentFakeEmbeddings(),
|
||||
collection_name=collection_name,
|
||||
path=str(tmpdir),
|
||||
distance_func="Euclid",
|
||||
)
|
||||
|
||||
expected_message = (
|
||||
"configured for COSINE similarity, but requested EUCLID. Please set "
|
||||
"`distance_func` parameter to `COSINE`"
|
||||
)
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||
def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||
@ -205,7 +213,7 @@ def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||
"""Test if Qdrant.from_texts recreates the collection even if config mismatches"""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
collection_name = uuid.uuid4().hex
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
@ -250,3 +258,27 @@ def test_qdrant_from_texts_stores_metadatas(
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
def test_from_texts_passed_optimizers_config_and_on_disk_payload() -> None:
|
||||
from qdrant_client import models
|
||||
|
||||
collection_name = uuid.uuid4().hex
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
optimizers_config = models.OptimizersConfigDiff(memmap_threshold=1000)
|
||||
vec_store = Qdrant.from_texts(
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
optimizers_config=optimizers_config,
|
||||
on_disk_payload=True,
|
||||
on_disk=True,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
collection_info = vec_store.client.get_collection(collection_name)
|
||||
assert collection_info.config.params.vectors.on_disk is True # type: ignore
|
||||
assert collection_info.config.optimizer_config.memmap_threshold == 1000
|
||||
assert collection_info.config.params.on_disk_payload is True
|
||||
|
Loading…
Reference in New Issue
Block a user