You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retrie...

210 lines
7.4 KiB
Python

import uuid
from itertools import islice
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
cast,
)
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_community.vectorstores.qdrant import Qdrant, QdrantException
class QdrantSparseVectorRetriever(BaseRetriever):
"""Qdrant sparse vector retriever."""
client: Any
"""'qdrant_client' instance to use."""
collection_name: str
"""Qdrant collection name."""
sparse_vector_name: str
"""Name of the sparse vector to use."""
sparse_encoder: Callable[[str], Tuple[List[int], List[float]]]
"""Sparse encoder function to use."""
k: int = 4
"""Number of documents to return per query. Defaults to 4."""
filter: Optional[Any] = None
"""Qdrant qdrant_client.models.Filter to use for queries. Defaults to None."""
content_payload_key: str = "content"
"""Payload field containing the document content. Defaults to 'content'"""
metadata_payload_key: str = "metadata"
"""Payload field containing the document metadata. Defaults to 'metadata'."""
search_options: Dict[str, Any] = {}
"""Additional search options to pass to qdrant_client.QdrantClient.search()."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that 'qdrant_client' python package exists in environment."""
try:
from grpc import RpcError
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
except ImportError:
raise ImportError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
client = values["client"]
if not isinstance(client, QdrantClient):
raise ValueError(
f"client should be an instance of qdrant_client.QdrantClient, "
f"got {type(client)}"
)
filter = values["filter"]
if filter is not None and not isinstance(filter, models.Filter):
raise ValueError(
f"filter should be an instance of qdrant_client.models.Filter, "
f"got {type(filter)}"
)
client = cast(QdrantClient, client)
collection_name = values["collection_name"]
sparse_vector_name = values["sparse_vector_name"]
try:
collection_info = client.get_collection(collection_name)
sparse_vectors_config = collection_info.config.params.sparse_vectors
if sparse_vector_name not in sparse_vectors_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain sparse vector named {sparse_vector_name}."
f"Did you mean one of {', '.join(sparse_vectors_config.keys())}?"
)
except (UnexpectedResponse, RpcError, ValueError):
raise QdrantException(
f"Qdrant collection {collection_name} does not exist."
)
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from qdrant_client import QdrantClient, models
client = cast(QdrantClient, self.client)
query_indices, query_values = self.sparse_encoder(query)
results = client.search(
self.collection_name,
query_filter=self.filter,
query_vector=models.NamedSparseVector(
name=self.sparse_vector_name,
vector=models.SparseVector(
indices=query_indices,
values=query_values,
),
),
limit=self.k,
with_vectors=False,
**self.search_options,
)
return [
Qdrant._document_from_scored_point(
point,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
)
for point in results
]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
documents (List[Document]: Documents to add to the vectorstore.
Returns:
List[str]: List of IDs of the added texts.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
from qdrant_client import QdrantClient
added_ids = []
client = cast(QdrantClient, self.client)
for batch_ids, points in self._generate_rest_batches(
texts, metadatas, ids, batch_size
):
client.upsert(self.collection_name, points=points, **kwargs)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
) -> Generator[Tuple[List[str], List[Any]], None, None]:
from qdrant_client import models as rest
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the sparse embeddings for all the texts in a batch
batch_embeddings: List[Tuple[List[int], List[float]]] = [
self.sparse_encoder(text) for text in batch_texts
]
points = [
rest.PointStruct(
id=point_id,
vector={
self.sparse_vector_name: rest.SparseVector(
indices=sparse_vector[0],
values=sparse_vector[1],
)
},
payload=payload,
)
for point_id, sparse_vector, payload in zip(
batch_ids,
batch_embeddings,
Qdrant._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
)
]
yield batch_ids, points