refactor: Qdrant async improvements (#14492)

Follow up on https://github.com/langchain-ai/langchain/pull/13048.
This PR intends to simplify the Qdrant async implementation by replacing
the internal GRPC methods with the `QdrantAsyncClient` methods.
This is a backward compatible change with no additional steps required
after merge.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/15451/head
Anush 9 months ago committed by GitHub
parent cda68d717c
commit 58cc7878e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,11 +22,11 @@ from typing import (
)
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.docstore.document import Document
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
@ -94,6 +94,7 @@ class Qdrant(VectorStore):
metadata_payload_key: str = METADATA_KEY,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
async_client: Optional[Any] = None,
embedding_function: Optional[Callable] = None, # deprecated
):
"""Initialize with necessary components."""
@ -111,6 +112,14 @@ class Qdrant(VectorStore):
f"got {type(client)}"
)
if async_client is not None and not isinstance(
async_client, qdrant_client.AsyncQdrantClient
):
raise ValueError(
f"async_client should be an instance of qdrant_client.AsyncQdrantClient"
f"got {type(async_client)}"
)
if embeddings is None and embedding_function is None:
raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance."
@ -125,6 +134,7 @@ class Qdrant(VectorStore):
self._embeddings = embeddings
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
@ -208,18 +218,21 @@ class Qdrant(VectorStore):
Returns:
List of ids from adding the texts into the vectorstore.
"""
from qdrant_client import grpc # noqa
from qdrant_client.conversions.conversion import RestToGrpc
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
added_ids = []
async for batch_ids, points in self._agenerate_rest_batches(
texts, metadatas, ids, batch_size
):
await self.client.async_grpc_points.Upsert(
grpc.UpsertPoints(
collection_name=self.collection_name,
points=[RestToGrpc.convert_point_struct(point) for point in points],
)
await self.async_client.upsert(
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
@ -399,7 +412,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas
**kwargs:
Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search().
AsyncQdrantClient.Search().
Returns:
List of documents most similar to the query text and distance for each.
@ -514,7 +527,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas
**kwargs:
Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search().
AsyncQdrantClient.Search().
Returns:
List of Documents most similar to the query.
@ -614,56 +627,6 @@ class Qdrant(VectorStore):
for result in results
]
async def _asearch_with_score_by_vector(
    self,
    embedding: List[float],
    *,
    k: int = 4,
    filter: Optional[MetadataFilter] = None,
    search_params: Optional[common_types.SearchParams] = None,
    offset: int = 0,
    score_threshold: Optional[float] = None,
    consistency: Optional[common_types.ReadConsistency] = None,
    with_vectors: bool = False,
    **kwargs: Any,
) -> Any:
    """Run a gRPC point search for ``embedding`` and return the raw response.

    Args:
        embedding: Query vector sent as the ``vector`` of the search request.
        k: Forwarded as the request's ``limit``.
        filter: Optional filter. A ``dict`` is deprecated: it triggers a
            ``DeprecationWarning`` and is converted with
            ``self._qdrant_filter_from_dict``. A REST ``Filter`` is converted
            to its gRPC form; anything else is passed through untouched.
        search_params: Forwarded as the request's ``params``.
        offset: Forwarded as the request's ``offset``.
        score_threshold: Forwarded as the request's ``score_threshold``.
        consistency: Forwarded as the request's ``read_consistency``.
        with_vectors: Toggles the ``WithVectorsSelector`` of the request.
        **kwargs: Extra keyword arguments for ``grpc.SearchPoints``.

    Returns:
        Whatever ``self.client.async_grpc_points.Search`` resolves to.
    """
    from qdrant_client import grpc  # noqa
    from qdrant_client.conversions.conversion import RestToGrpc
    from qdrant_client.http import models as rest

    search_filter = filter
    if isinstance(filter, dict):
        warnings.warn(
            "Using dict as a `filter` is deprecated. Please use qdrant-client "
            "filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/",
            DeprecationWarning,
        )
        search_filter = self._qdrant_filter_from_dict(filter)

    # The gRPC request only accepts gRPC filters, so REST ones are converted.
    if isinstance(search_filter, rest.Filter):
        search_filter = RestToGrpc.convert_filter(search_filter)

    # Resolve the RPC before building the request, as the original call did.
    search_rpc = self.client.async_grpc_points.Search
    request = grpc.SearchPoints(
        collection_name=self.collection_name,
        vector_name=self.vector_name,
        vector=embedding,
        filter=search_filter,
        params=search_params,
        limit=k,
        offset=offset,
        with_payload=grpc.WithPayloadSelector(enable=True),
        with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
        score_threshold=score_threshold,
        read_consistency=consistency,
        **kwargs,
    )
    return await search_rpc(request)
@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
@ -706,30 +669,55 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas
**kwargs:
Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search().
AsyncQdrantClient.Search().
Returns:
List of documents most similar to the query text and distance for each.
"""
response = await self._asearch_with_score_by_vector(
embedding,
k=k,
filter=filter,
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=qdrant_filter,
search_params=search_params,
limit=k,
offset=offset,
with_payload=True,
with_vectors=False, # Langchain does not expect vectors to be returned
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return [
(
self._document_from_scored_point_grpc(
self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
),
result.score,
)
for result in response.result
for result in results
]
def max_marginal_relevance_search(
@ -843,7 +831,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas
**kwargs:
Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search().
AsyncQdrantClient.Search().
Returns:
List of Documents selected by maximal marginal relevance.
"""
@ -968,7 +956,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas
**kwargs:
Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search().
AsyncQdrantClient.Search().
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
@ -1099,41 +1087,45 @@ class Qdrant(VectorStore):
List of Documents selected by maximal marginal relevance and distance for
each.
"""
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
response = await self._asearch_with_score_by_vector(
embedding,
k=fetch_k,
filter=filter,
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=search_params,
limit=fetch_k,
with_payload=True,
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
with_vectors=True,
**kwargs,
)
results = [
GrpcToRest.convert_vectors(result.vectors) for result in response.result
]
embeddings: List[List[float]] = [
result.get(self.vector_name) # type: ignore
if isinstance(result, dict)
else result
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected: List[int] = maximal_marginal_relevance(
np.array(embedding),
embeddings,
k=k,
lambda_mult=lambda_mult,
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point_grpc(
response.result[i],
self.content_payload_key,
self.metadata_payload_key,
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
),
response.result[i].score,
results[i].score,
)
for i in mmr_selected
]
@ -1543,7 +1535,7 @@ class Qdrant(VectorStore):
**kwargs: Any,
) -> Qdrant:
try:
import qdrant_client
import qdrant_client # noqa
except ImportError:
raise ValueError(
"Could not import qdrant-client python package. "
@ -1558,7 +1550,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
client = qdrant_client.QdrantClient(
client, async_client = cls._generate_clients(
location=location,
url=url,
port=port,
@ -1669,6 +1661,7 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
async_client=async_client,
)
return qdrant
@ -1707,7 +1700,7 @@ class Qdrant(VectorStore):
**kwargs: Any,
) -> Qdrant:
try:
import qdrant_client
import qdrant_client # noqa
except ImportError:
raise ValueError(
"Could not import qdrant-client python package. "
@ -1722,7 +1715,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
client = qdrant_client.QdrantClient(
client, async_client = cls._generate_clients(
location=location,
url=url,
port=port,
@ -1833,6 +1826,7 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
async_client=async_client,
)
return qdrant
@ -1922,21 +1916,6 @@ class Qdrant(VectorStore):
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
@classmethod
def _document_from_scored_point_grpc(
    cls,
    scored_point: Any,
    content_payload_key: str,
    metadata_payload_key: str,
) -> Document:
    """Build a ``Document`` from a scored point carrying a gRPC payload.

    Args:
        scored_point: Object whose ``payload`` attribute is decoded with
            ``grpc_to_payload``.
        content_payload_key: Key of the decoded payload holding the page
            content; it must be present.
        metadata_payload_key: Key of the decoded payload holding the metadata;
            a missing or falsy value yields an empty dict.

    Returns:
        A ``Document`` with the extracted page content and metadata.
    """
    from qdrant_client.conversions.conversion import grpc_to_payload

    decoded = grpc_to_payload(scored_point.payload)
    page_content = decoded[content_payload_key]
    metadata = decoded.get(metadata_payload_key) or {}
    return Document(page_content=page_content, metadata=metadata)
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
from qdrant_client.http import models as rest
@ -2134,3 +2113,57 @@ class Qdrant(VectorStore):
]
yield batch_ids, points
@staticmethod
def _generate_clients(
    location: Optional[str] = None,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[float] = None,
    host: Optional[str] = None,
    path: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """Create the sync client and, when possible, a matching async client.

    Every argument is forwarded unchanged to ``QdrantClient`` and, for
    non-local setups, to ``AsyncQdrantClient`` as well.

    Returns:
        A ``(sync_client, async_client)`` tuple. ``async_client`` is ``None``
        when ``location`` is ``":memory:"`` or ``path`` is given.
    """
    from qdrant_client import AsyncQdrantClient, QdrantClient

    # Both clients take the same options, so they are collected once.
    client_options = {
        "location": location,
        "url": url,
        "port": port,
        "grpc_port": grpc_port,
        "prefer_grpc": prefer_grpc,
        "https": https,
        "api_key": api_key,
        "prefix": prefix,
        "timeout": timeout,
        "host": host,
        "path": path,
        **kwargs,
    }
    sync_client = QdrantClient(**client_options)

    if location == ":memory:" or path is not None:
        # Local Qdrant cannot co-exist with Sync and Async clients
        # We fallback to sync operations in this case
        return sync_client, None

    return sync_client, AsyncQdrantClient(**client_options)

Loading…
Cancel
Save