mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Implement async support in Qdrant local mode (#8001)
I've extended the support of async API to local Qdrant mode. It is faked but allows prototyping without spinning a container. The tests are improved to test the in-memory case as well. @baskaryan @rlancemartin @eyurtsev @agola11
This commit is contained in:
parent
7717c24fc4
commit
ed6a5532ac
@ -114,7 +114,6 @@
|
||||
" \"rating\": 9.9,\n",
|
||||
" \"director\": \"Andrei Tarkovsky\",\n",
|
||||
" \"genre\": \"science fiction\",\n",
|
||||
" \"rating\": 9.9,\n",
|
||||
" },\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Wrapper around Qdrant vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import uuid
|
||||
import warnings
|
||||
from itertools import islice
|
||||
@ -40,6 +42,30 @@ class QdrantException(Exception):
|
||||
"""Base class for all the Qdrant related exceptions"""
|
||||
|
||||
|
||||
def sync_call_fallback(method: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to call the synchronous method of the class if the async method is not
|
||||
implemented. This decorator might be only used for the methods that are defined
|
||||
as async in the class.
|
||||
"""
|
||||
|
||||
@functools.wraps(method)
|
||||
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await method(self, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
# If the async method is not implemented, call the synchronous method
|
||||
# by removing the first letter from the method name. For example,
|
||||
# if the async method is called ``aaad_texts``, the synchronous method
|
||||
# will be called ``aad_texts``.
|
||||
sync_method = functools.partial(
|
||||
getattr(self, method.__name__[1:]), *args, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Qdrant(VectorStore):
|
||||
"""Wrapper around Qdrant vector database.
|
||||
|
||||
@ -155,6 +181,7 @@ class Qdrant(VectorStore):
|
||||
|
||||
return added_ids
|
||||
|
||||
@sync_call_fallback
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -250,6 +277,7 @@ class Qdrant(VectorStore):
|
||||
)
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search(
|
||||
self,
|
||||
query: str,
|
||||
@ -322,6 +350,7 @@ class Qdrant(VectorStore):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
@ -431,6 +460,7 @@ class Qdrant(VectorStore):
|
||||
)
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
@ -567,6 +597,7 @@ class Qdrant(VectorStore):
|
||||
for result in results
|
||||
]
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
@ -685,6 +716,7 @@ class Qdrant(VectorStore):
|
||||
query_embedding, k, fetch_k, lambda_mult, **kwargs
|
||||
)
|
||||
|
||||
@sync_call_fallback
|
||||
async def amax_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
@ -739,33 +771,12 @@ class Qdrant(VectorStore):
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_vector = embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=query_vector,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
limit=fetch_k,
|
||||
results = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
|
||||
)
|
||||
embeddings = [
|
||||
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||
if self.vector_name is not None
|
||||
else result.vector
|
||||
for result in results
|
||||
]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
return [
|
||||
self._document_from_scored_point(
|
||||
results[i], self.content_payload_key, self.metadata_payload_key
|
||||
)
|
||||
for i in mmr_selected
|
||||
]
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
@sync_call_fallback
|
||||
async def amax_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
@ -795,6 +806,61 @@ class Qdrant(VectorStore):
|
||||
)
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
"""
|
||||
query_vector = embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=query_vector,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
limit=fetch_k,
|
||||
)
|
||||
embeddings = [
|
||||
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||
if self.vector_name is not None
|
||||
else result.vector
|
||||
for result in results
|
||||
]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
return [
|
||||
(
|
||||
self._document_from_scored_point(
|
||||
results[i], self.content_payload_key, self.metadata_payload_key
|
||||
),
|
||||
results[i].score,
|
||||
)
|
||||
for i in mmr_selected
|
||||
]
|
||||
|
||||
@sync_call_fallback
|
||||
async def amax_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
@ -1038,7 +1104,6 @@ class Qdrant(VectorStore):
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
vector_name,
|
||||
batch_size,
|
||||
shard_number,
|
||||
replication_factor,
|
||||
write_consistency_factor,
|
||||
@ -1055,6 +1120,7 @@ class Qdrant(VectorStore):
|
||||
return qdrant
|
||||
|
||||
@classmethod
|
||||
@sync_call_fallback
|
||||
async def afrom_texts(
|
||||
cls: Type[Qdrant],
|
||||
texts: List[str],
|
||||
@ -1214,7 +1280,6 @@ class Qdrant(VectorStore):
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
vector_name,
|
||||
batch_size,
|
||||
shard_number,
|
||||
replication_factor,
|
||||
write_consistency_factor,
|
||||
@ -1253,7 +1318,6 @@ class Qdrant(VectorStore):
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
batch_size: int = 64,
|
||||
shard_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
write_consistency_factor: Optional[int] = None,
|
||||
|
@ -0,0 +1,13 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def qdrant_locations() -> List[str]:
|
||||
if qdrant_is_not_running():
|
||||
logger.warning("Running Qdrant async tests in memory mode only.")
|
||||
return [":memory:"]
|
||||
return ["http://localhost:6333", ":memory:"]
|
@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
from .common import qdrant_is_not_running
|
||||
|
||||
# Skipping all the tests in the module if Qdrant is not running on localhost.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
qdrant_is_not_running(), reason="Qdrant server is not running"
|
||||
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa
|
||||
qdrant_locations,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_aadd_texts_returns_all_ids(
|
||||
batch_size: int, qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.aadd_texts returns unique ids."""
|
||||
docsearch: Qdrant = Qdrant.from_texts(
|
||||
["foobar"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
batch_size=batch_size,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
|
||||
@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_aadd_texts_stores_duplicated_texts(
|
||||
vector_name: Optional[str],
|
||||
vector_name: Optional[str], qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
client = QdrantClient()
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
collection_name = "test"
|
||||
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
|
||||
if vector_name is not None:
|
||||
@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_aadd_texts_stores_ids(
|
||||
batch_size: int, qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.aadd_texts stores provided ids."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
|
||||
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
|
||||
]
|
||||
|
||||
client = QdrantClient()
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
collection_name = "test"
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", ["custom-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
|
||||
vector_name: str,
|
||||
vector_name: str, qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
|
||||
client = QdrantClient()
|
||||
client = QdrantClient(location=qdrant_location)
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
vectors_config={
|
@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
from .common import qdrant_is_not_running
|
||||
|
||||
# Skipping all the tests in the module if Qdrant is not running on localhost.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
qdrant_is_not_running(), reason="Qdrant server is not running"
|
||||
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
|
||||
qdrant_locations,
|
||||
)
|
||||
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_from_texts_stores_duplicated_texts() -> None:
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None:
|
||||
"""Test end to end Qdrant.afrom_texts stores duplicated texts separately."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
await Qdrant.afrom_texts(
|
||||
vec_store = await Qdrant.afrom_texts(
|
||||
["abc", "abc"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
collection_name=collection_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
client = QdrantClient()
|
||||
client = vec_store.client
|
||||
assert 2 == client.count(collection_name).count
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_from_texts_stores_ids(
|
||||
batch_size: int, vector_name: Optional[str]
|
||||
batch_size: int, vector_name: Optional[str], qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.afrom_texts stores provided ids."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = uuid.uuid4().hex
|
||||
ids = [
|
||||
"fa38d572-4c31-4579-aedc-1960d79df6df",
|
||||
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
|
||||
]
|
||||
await Qdrant.afrom_texts(
|
||||
vec_store = await Qdrant.afrom_texts(
|
||||
["abc", "def"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
ids=ids,
|
||||
collection_name=collection_name,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
client = QdrantClient()
|
||||
client = vec_store.client
|
||||
assert 2 == client.count(collection_name).count
|
||||
stored_ids = [point.id for point in client.scroll(collection_name)[0]]
|
||||
assert set(ids) == set(stored_ids)
|
||||
@ -66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", ["custom-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
|
||||
vector_name: str,
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.afrom_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = uuid.uuid4().hex
|
||||
|
||||
await Qdrant.afrom_texts(
|
||||
vec_store = await Qdrant.afrom_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
collection_name=collection_name,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
client = QdrantClient()
|
||||
client = vec_store.client
|
||||
assert 5 == client.count(collection_name).count
|
||||
assert all(
|
||||
vector_name in point.vector # type: ignore[operator]
|
||||
@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
async def test_qdrant_from_texts_reuses_same_collection(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test if Qdrant.afrom_texts reuses the same collection"""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = uuid.uuid4().hex
|
||||
embeddings = ConsistentFakeEmbeddings()
|
||||
|
||||
@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection(
|
||||
vector_name=vector_name,
|
||||
)
|
||||
|
||||
await Qdrant.afrom_texts(
|
||||
vec_store = await Qdrant.afrom_texts(
|
||||
["foo", "bar"],
|
||||
embeddings,
|
||||
collection_name=collection_name,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
|
||||
client = QdrantClient()
|
||||
client = vec_store.client
|
||||
assert 7 == client.count(collection_name).count
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
|
||||
("my-first-vector", "my-second_vector"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
async def test_qdrant_from_texts_raises_error_on_different_vector_name(
|
||||
first_vector_name: Optional[str],
|
||||
second_vector_name: Optional[str],
|
||||
@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
|
||||
"""Test if Qdrant.afrom_texts raises an exception if distance does not match"""
|
||||
collection_name = uuid.uuid4().hex
|
||||
@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
|
||||
async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_from_texts_stores_metadatas(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas(
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
location=qdrant_location,
|
||||
)
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
from .common import qdrant_is_not_running
|
||||
|
||||
# Skipping all the tests in the module if Qdrant is not running on localhost.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
qdrant_is_not_running(), reason="Qdrant server is not running"
|
||||
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
|
||||
qdrant_locations,
|
||||
)
|
||||
|
||||
|
||||
@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif(
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_max_marginal_relevance_search(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and MRR search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
distance_func="EUCLID", # Euclid distance used to avoid normalization
|
||||
)
|
||||
output = await docsearch.amax_marginal_relevance_search(
|
||||
"foo", k=2, fetch_k=3, lambda_mult=0.0
|
||||
)
|
||||
output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3)
|
||||
assert output == [
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
Document(page_content="baz", metadata={"page": 2}),
|
@ -9,12 +9,8 @@ from langchain.vectorstores import Qdrant
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
from .common import qdrant_is_not_running
|
||||
|
||||
# Skipping all the tests in the module if Qdrant is not running on localhost.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
qdrant_is_not_running(), reason="Qdrant server is not running"
|
||||
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
|
||||
qdrant_locations,
|
||||
)
|
||||
|
||||
|
||||
@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif(
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -38,6 +36,7 @@ async def test_qdrant_similarity_search(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
@ -48,11 +47,13 @@ async def test_qdrant_similarity_search(
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_by_vector(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
|
||||
output = await docsearch.asimilarity_search_by_vector(embeddings, k=1)
|
||||
@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector(
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_with_score_by_vector(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
|
||||
output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1)
|
||||
@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector(
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_filters(
|
||||
batch_size: int, vector_name: Optional[str]
|
||||
batch_size: int, vector_name: Optional[str], qdrant_location: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters(
|
||||
metadatas=metadatas,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
output = await docsearch.asimilarity_search(
|
||||
@ -131,8 +138,10 @@ async def test_qdrant_similarity_search_filters(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
output = await docsearch.asimilarity_search_with_relevance_scores(
|
||||
"foo", k=3, score_threshold=None
|
||||
@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
score_threshold = 0.98
|
||||
@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
score_threshold = 0.99 # for almost exact match
|
||||
# test negative filter condition
|
||||
@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_filters_with_qdrant_filters(
|
||||
vector_name: Optional[str],
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
|
||||
qdrant_filter = rest.Filter(
|
||||
@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
|
||||
async def test_qdrant_similarity_search_with_relevance_scores(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: str,
|
||||
qdrant_location: str,
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
location=qdrant_location,
|
||||
)
|
||||
output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
|
||||
|
@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server")
|
||||
def test_qdrant_max_marginal_relevance_search(
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
distance_func="EUCLID", # Euclid distance used to avoid normalization
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search(
|
||||
"foo", k=2, fetch_k=3, lambda_mult=0.0
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
|
||||
assert output == [
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
Document(page_content="baz", metadata={"page": 2}),
|
||||
|
Loading…
Reference in New Issue
Block a user