Implement async support in Qdrant local mode (#8001)

I've extended the support of async API to local Qdrant mode. It is faked
but allows prototyping without spinning a container. The tests are
improved to test the in-memory case as well.

@baskaryan @rlancemartin @eyurtsev @agola11
This commit is contained in:
Kacper Łukawski 2023-07-21 04:04:33 +02:00 committed by GitHub
parent 7717c24fc4
commit ed6a5532ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 196 additions and 83 deletions

View File

@ -114,7 +114,6 @@
" \"rating\": 9.9,\n",
" \"director\": \"Andrei Tarkovsky\",\n",
" \"genre\": \"science fiction\",\n",
" \"rating\": 9.9,\n",
" },\n",
" ),\n",
"]\n",

View File

@ -1,6 +1,8 @@
"""Wrapper around Qdrant vector database."""
from __future__ import annotations
import asyncio
import functools
import uuid
import warnings
from itertools import islice
@ -40,6 +42,30 @@ class QdrantException(Exception):
"""Base class for all the Qdrant related exceptions"""
def sync_call_fallback(method: Callable) -> Callable:
"""
Decorator to call the synchronous method of the class if the async method is not
implemented. This decorator might be only used for the methods that are defined
as async in the class.
"""
@functools.wraps(method)
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await method(self, *args, **kwargs)
except NotImplementedError:
# If the async method is not implemented, call the synchronous method
# by removing the first letter from the method name. For example,
# if the async method is called ``aaad_texts``, the synchronous method
# will be called ``aad_texts``.
sync_method = functools.partial(
getattr(self, method.__name__[1:]), *args, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
return wrapper
class Qdrant(VectorStore):
"""Wrapper around Qdrant vector database.
@ -155,6 +181,7 @@ class Qdrant(VectorStore):
return added_ids
@sync_call_fallback
async def aadd_texts(
self,
texts: Iterable[str],
@ -250,6 +277,7 @@ class Qdrant(VectorStore):
)
return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search(
self,
query: str,
@ -322,6 +350,7 @@ class Qdrant(VectorStore):
**kwargs,
)
@sync_call_fallback
async def asimilarity_search_with_score(
self,
query: str,
@ -431,6 +460,7 @@ class Qdrant(VectorStore):
)
return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search_by_vector(
self,
embedding: List[float],
@ -567,6 +597,7 @@ class Qdrant(VectorStore):
for result in results
]
@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
@ -685,6 +716,7 @@ class Qdrant(VectorStore):
query_embedding, k, fetch_k, lambda_mult, **kwargs
)
@sync_call_fallback
async def amax_marginal_relevance_search(
self,
query: str,
@ -739,33 +771,12 @@ class Qdrant(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
results = self.max_marginal_relevance_search_with_score_by_vector(
embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
)
for i in mmr_selected
]
return list(map(itemgetter(0), results))
@sync_call_fallback
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
@ -795,6 +806,61 @@ class Qdrant(VectorStore):
)
return list(map(itemgetter(0), results))
def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
),
results[i].score,
)
for i in mmr_selected
]
@sync_call_fallback
async def amax_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
@ -1038,7 +1104,6 @@ class Qdrant(VectorStore):
content_payload_key,
metadata_payload_key,
vector_name,
batch_size,
shard_number,
replication_factor,
write_consistency_factor,
@ -1055,6 +1120,7 @@ class Qdrant(VectorStore):
return qdrant
@classmethod
@sync_call_fallback
async def afrom_texts(
cls: Type[Qdrant],
texts: List[str],
@ -1214,7 +1280,6 @@ class Qdrant(VectorStore):
content_payload_key,
metadata_payload_key,
vector_name,
batch_size,
shard_number,
replication_factor,
write_consistency_factor,
@ -1253,7 +1318,6 @@ class Qdrant(VectorStore):
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,

View File

@ -0,0 +1,13 @@
import logging
from typing import List
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
logger = logging.getLogger(__name__)
def qdrant_locations() -> List[str]:
if qdrant_is_not_running():
logger.warning("Running Qdrant async tests in memory mode only.")
return [":memory:"]
return ["http://localhost:6333", ":memory:"]

View File

@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa
qdrant_locations,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_returns_all_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts returns unique ids."""
docsearch: Qdrant = Qdrant.from_texts(
["foobar"],
ConsistentFakeEmbeddings(),
batch_size=batch_size,
location=qdrant_location,
)
ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_duplicated_texts(
vector_name: Optional[str],
vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None:
@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores provided ids."""
from qdrant_client import QdrantClient
@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
client.recreate_collection(
collection_name,
@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
vector_name: str,
vector_name: str, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = "test"
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
client.recreate_collection(
collection_name,
vectors_config={

View File

@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
@pytest.mark.asyncio
async def test_qdrant_from_texts_stores_duplicated_texts() -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None:
"""Test end to end Qdrant.afrom_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["abc", "abc"],
ConsistentFakeEmbeddings(),
collection_name=collection_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 2 == client.count(collection_name).count
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_ids(
batch_size: int, vector_name: Optional[str]
batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end Qdrant.afrom_texts stores provided ids."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df",
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["abc", "def"],
ConsistentFakeEmbeddings(),
ids=ids,
collection_name=collection_name,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 2 == client.count(collection_name).count
stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids)
@ -66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
vector_name: str,
qdrant_location: str,
) -> None:
"""Test end to end Qdrant.afrom_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
ConsistentFakeEmbeddings(),
collection_name=collection_name,
vector_name=vector_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 5 == client.count(collection_name).count
assert all(
vector_name in point.vector # type: ignore[operator]
@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_reuses_same_collection(
vector_name: Optional[str],
) -> None:
"""Test if Qdrant.afrom_texts reuses the same collection"""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings()
@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection(
vector_name=vector_name,
)
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["foo", "bar"],
embeddings,
collection_name=collection_name,
vector_name=vector_name,
)
client = QdrantClient()
client = vec_store.client
assert 7 == client.count(collection_name).count
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str],
) -> None:
@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
("my-first-vector", "my-second_vector"),
],
)
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_vector_name(
first_vector_name: Optional[str],
second_vector_name: Optional[str],
@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
@pytest.mark.asyncio
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
"""Test if Qdrant.afrom_texts raises an exception if distance does not match"""
collection_name = uuid.uuid4().hex
@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
vector_name: Optional[str],
) -> None:
@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_metadatas(
batch_size: int, content_payload_key: str, metadata_payload_key: str
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas(
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
location=qdrant_location,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]

View File

@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_max_marginal_relevance_search(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = await docsearch.amax_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
)
output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}),

View File

@ -9,12 +9,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -38,6 +36,7 @@ async def test_qdrant_similarity_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
@ -48,11 +47,13 @@ async def test_qdrant_similarity_search(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_by_vector(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_by_vector(embeddings, k=1)
@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_score_by_vector(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1)
@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector(
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters(
batch_size: int, vector_name: Optional[str]
batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters(
metadatas=metadatas,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search(
@ -131,8 +138,10 @@ async def test_qdrant_similarity_search_filters(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search_with_relevance_scores(
"foo", k=3, score_threshold=None
@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
score_threshold = 0.98
@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
score_threshold = 0.99 # for almost exact match
# test negative filter condition
@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters_with_qdrant_filters(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
qdrant_filter = rest.Filter(
@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_scores(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: str,
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)

View File

@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server")
def test_qdrant_max_marginal_relevance_search(
batch_size: int,
content_payload_key: str,
@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = docsearch.max_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}),