# Source: langchain/langchain/vectorstores/qdrant.py (Python, 547 lines, 19 KiB)

"""Wrapper around Qdrant vector database."""
from __future__ import annotations
import uuid
import warnings
from hashlib import md5
from itertools import islice
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
)
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
    # These imports are only needed by static type checkers; keeping them
    # behind TYPE_CHECKING means importing this module does not require
    # qdrant-client at runtime.
    from qdrant_client.conversions import common_types
    from qdrant_client.http import models as rest

    # Plain-dict filter form (metadata key -> value / nested dict / list).
    DictFilter = Dict[str, Union[str, int, bool, dict, list]]
    # A metadata filter is either the dict form above or a qdrant-client
    # filter object. The alias references `common_types`, so it has to live
    # inside this guard as well; it is only used in (lazy) annotations.
    MetadataFilter = Union[DictFilter, common_types.Filter]
class Qdrant(VectorStore):
    """Wrapper around Qdrant vector database.

    To use you should have the ``qdrant-client`` package installed.

    Example:
        .. code-block:: python

            from qdrant_client import QdrantClient
            from langchain import Qdrant

            client = QdrantClient()
            collection_name = "MyCollection"
            qdrant = Qdrant(client, collection_name, embeddings)
    """

    # Default payload keys under which text and metadata are stored.
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"

    def __init__(
        self,
        client: Any,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        embedding_function: Optional[Callable] = None,  # deprecated
    ):
        """Initialize with necessary components.

        Args:
            client: A ``qdrant_client.QdrantClient`` instance.
            collection_name: Name of the collection used by this store.
            embeddings: ``Embeddings`` instance used to embed texts and
                queries. For backward compatibility any other non-None
                object passed here is used as ``embedding_function``.
            content_payload_key: Payload key holding the document text.
                Falsy values fall back to ``CONTENT_KEY``.
            metadata_payload_key: Payload key holding the document metadata.
                Falsy values fall back to ``METADATA_KEY``.
            embedding_function: Deprecated. Callable taking a single text and
                returning its vector. Mutually exclusive with ``embeddings``.

        Raises:
            ValueError: If ``qdrant-client`` cannot be imported, if ``client``
                is not a ``QdrantClient``, or if not exactly one of
                ``embeddings`` / ``embedding_function`` is provided.
        """
        try:
            import qdrant_client
        except ImportError:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            )

        if not isinstance(client, qdrant_client.QdrantClient):
            raise ValueError(
                f"client should be an instance of qdrant_client.QdrantClient, "
                f"got {type(client)}"
            )

        if embeddings is None and embedding_function is None:
            raise ValueError(
                "`embeddings` value can't be None. Pass `Embeddings` instance."
            )

        if embeddings is not None and embedding_function is not None:
            raise ValueError(
                "Both `embeddings` and `embedding_function` are passed. "
                "Use `embeddings` only."
            )

        self.embeddings = embeddings
        self._embeddings_function = embedding_function
        self.client: qdrant_client.QdrantClient = client
        self.collection_name = collection_name
        self.content_payload_key = content_payload_key or self.CONTENT_KEY
        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY

        if embedding_function is not None:
            warnings.warn(
                "Using `embedding_function` is deprecated. "
                "Pass `Embeddings` instance to `embeddings` instead."
            )

        # Only fall back when something was actually passed as `embeddings`.
        # Without the `is not None` guard, a caller supplying only
        # `embedding_function` would have it overwritten with None here.
        if embeddings is not None and not isinstance(embeddings, Embeddings):
            warnings.warn(
                "`embeddings` should be an instance of `Embeddings`. "
                "Using `embeddings` as `embedding_function` which is deprecated"
            )
            self._embeddings_function = embeddings
            self.embeddings = None

    def _embed_query(self, query: str) -> List[float]:
        """Embed query text.

        Used to provide backward compatibility with `embedding_function`
        argument.

        Args:
            query: Query text.

        Returns:
            List of floats representing the query embedding.

        Raises:
            ValueError: If neither an ``Embeddings`` instance nor an
                embedding function is set.
        """
        if self.embeddings is not None:
            embedding = self.embeddings.embed_query(query)
        elif self._embeddings_function is not None:
            embedding = self._embeddings_function(query)
        else:
            raise ValueError("Neither of embeddings or embedding_function is set")
        # Array-like results (anything exposing `tolist`) are converted to
        # plain lists; everything else is returned unchanged.
        return embedding.tolist() if hasattr(embedding, "tolist") else embedding

    def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
        """Embed search texts.

        Used to provide backward compatibility with `embedding_function`
        argument.

        Args:
            texts: Iterable of texts to embed.

        Returns:
            One embedding (list of floats) per input text, in input order.

        Raises:
            ValueError: If neither an ``Embeddings`` instance nor an
                embedding function is set.
        """
        if self.embeddings is not None:
            embeddings = self.embeddings.embed_documents(list(texts))
            if hasattr(embeddings, "tolist"):
                embeddings = embeddings.tolist()
        elif self._embeddings_function is not None:
            embeddings = []
            for text in texts:
                embedding = self._embeddings_function(text)
                # Check the individual vector, not the accumulator list.
                if hasattr(embedding, "tolist"):
                    embedding = embedding.tolist()
                embeddings.append(embedding)
        else:
            raise ValueError("Neither of embeddings or embedding_function is set")

        return embeddings

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 64,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            batch_size: Number of texts embedded and upserted per request.
                Defaults to 64.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List of ids from adding the texts into the vectorstore. Each id is
            the MD5 hex digest of the UTF-8 encoded text, so identical texts
            share the same id.
        """
        from qdrant_client.http import models as rest

        ids = []
        texts_iterator = iter(texts)
        metadatas_iterator = iter(metadatas or [])
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]

            self.client.upsert(
                collection_name=self.collection_name,
                points=rest.Batch.construct(
                    ids=batch_ids,
                    vectors=self._embed_texts(batch_texts),
                    payloads=self._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        self.content_payload_key,
                        self.metadata_payload_key,
                    ),
                ),
            )

            ids.extend(batch_ids)

        return ids

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[MetadataFilter] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List of Documents most similar to the query.
        """
        results = self.similarity_search_with_score(query, k, filter=filter)
        # Each result is a (Document, score) pair; keep only the Document.
        return list(map(itemgetter(0), results))

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None. A plain dict is
                still accepted but deprecated; it is converted to a Qdrant
                filter after emitting a ``DeprecationWarning``.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        if isinstance(filter, dict):
            warnings.warn(
                "Using dict as a `filter` is deprecated. Please use qdrant-client "
                "filters directly: "
                "https://qdrant.tech/documentation/concepts/filtering/",
                DeprecationWarning,
            )
            qdrant_filter = self._qdrant_filter_from_dict(filter)
        else:
            qdrant_filter = filter

        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=self._embed_query(query),
            query_filter=qdrant_filter,
            with_payload=True,
            limit=k,
        )
        return [
            (
                self._document_from_scored_point(
                    result, self.content_payload_key, self.metadata_payload_key
                ),
                result.score,
            )
            for result in results
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self._embed_query(query)
        # Vectors are requested alongside payloads because the MMR step
        # needs the stored embeddings of the candidates.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            with_payload=True,
            with_vectors=True,
            limit=fetch_k,
        )
        embeddings = [result.vector for result in results]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
        )
        return [
            self._document_from_scored_point(
                results[i], self.content_payload_key, self.metadata_payload_key
            )
            for i in mmr_selected
        ]

    @classmethod
    def from_texts(
        cls: Type[Qdrant],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        collection_name: Optional[str] = None,
        distance_func: str = "Cosine",
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        batch_size: int = 64,
        **kwargs: Any,
    ) -> Qdrant:
        """Construct Qdrant wrapper from a list of texts.

        Args:
            texts: A list of texts to be indexed in Qdrant.
            embedding: A subclass of `Embeddings`, responsible for text
                vectorization.
            metadatas:
                An optional list of metadata. If provided it has to be of the
                same length as a list of texts.
            location:
                If `:memory:` - use in-memory Qdrant instance.
                If `str` - use it as a `url` parameter.
                If `None` - fallback to relying on `host` and `port`
                parameters.
            url: either host or str of "Optional[scheme], host, Optional[port],
                Optional[prefix]". Default: `None`
            port: Port of the REST API interface. Default: 6333
            grpc_port: Port of the gRPC interface. Default: 6334
            prefer_grpc:
                If true - use gRPC interface whenever possible in custom
                methods. Default: False
            https: If true - use HTTPS(SSL) protocol. Default: None
            api_key: API key for authentication in Qdrant Cloud. Default: None
            prefix:
                If not None - add prefix to the REST URL path.
                Example: service/v1 will result in
                http://localhost:6333/service/v1/{qdrant-endpoint} for REST
                API. Default: None
            timeout:
                Timeout for REST and gRPC API requests.
                Default: 5.0 seconds for REST and unlimited for gRPC
            host:
                Host name of Qdrant service. If url and host are None, set to
                'localhost'. Default: None
            path:
                Path in which the vectors will be stored while using local
                mode. Default: None
            collection_name:
                Name of the Qdrant collection to be used. If not provided,
                a random UUID hex string is used. Default: None
            distance_func:
                Distance function. One of: "Cosine" / "Euclid" / "Dot"
                (matched case-insensitively). Default: "Cosine"
            content_payload_key:
                A payload key used to store the content of the document.
                Default: "page_content"
            metadata_payload_key:
                A payload key used to store the metadata of the document.
                Default: "metadata"
            batch_size:
                Number of texts embedded and upserted per request.
                Default: 64
            **kwargs:
                Additional arguments passed directly into REST client
                initialization

        This is a user-friendly interface that:
            1. Creates embeddings, one for each text
            2. Creates a Qdrant client from the connection arguments and
               (re)creates the target collection, sized from the embedding
               of the first text
            3. Adds the text embeddings to the Qdrant database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import Qdrant
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                qdrant = Qdrant.from_texts(texts, embeddings, host="localhost")

        Raises:
            ValueError: If ``qdrant-client`` cannot be imported.
        """
        try:
            import qdrant_client
        except ImportError:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            )

        from qdrant_client.http import models as rest

        # Just do a single quick embedding to get vector size
        partial_embeddings = embedding.embed_documents(texts[:1])
        vector_size = len(partial_embeddings[0])

        collection_name = collection_name or uuid.uuid4().hex
        # `rest.Distance` members are looked up by their upper-case name.
        distance_func = distance_func.upper()

        client = qdrant_client.QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
            **kwargs,
        )

        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(
                size=vector_size,
                distance=rest.Distance[distance_func],
            ),
        )

        texts_iterator = iter(texts)
        metadatas_iterator = iter(metadatas or [])
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None

            # Generate the embeddings for all the texts in a batch
            batch_embeddings = embedding.embed_documents(batch_texts)

            client.upsert(
                collection_name=collection_name,
                points=rest.Batch.construct(
                    ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
                    vectors=batch_embeddings,
                    payloads=cls._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        content_payload_key,
                        metadata_payload_key,
                    ),
                ),
            )

        return cls(
            client=client,
            collection_name=collection_name,
            embeddings=embedding,
            content_payload_key=content_payload_key,
            metadata_payload_key=metadata_payload_key,
        )

    @classmethod
    def _build_payloads(
        cls,
        texts: Iterable[str],
        metadatas: Optional[List[dict]],
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> List[dict]:
        """Build one Qdrant payload dict per text.

        Args:
            texts: Texts to store; none of them may be None.
            metadatas: Metadata indexed in parallel with ``texts``, or None
                to store ``None`` as the metadata of every text.
            content_payload_key: Payload key for the text.
            metadata_payload_key: Payload key for the metadata.

        Returns:
            List of ``{content_payload_key: text, metadata_payload_key:
            metadata}`` dicts, in the order of ``texts``.

        Raises:
            ValueError: If any text is None.
        """
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "At least one of the texts is None. Please remove it before "
                    "calling .from_texts or .add_texts on Qdrant instance."
                )
            metadata = metadatas[i] if metadatas is not None else None
            payloads.append(
                {
                    content_payload_key: text,
                    metadata_payload_key: metadata,
                }
            )

        return payloads

    @classmethod
    def _document_from_scored_point(
        cls,
        scored_point: Any,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> Document:
        """Convert a search result into a ``Document``.

        Args:
            scored_point: Search result exposing a ``payload`` mapping.
            content_payload_key: Payload key holding the text.
            metadata_payload_key: Payload key holding the metadata.

        Returns:
            Document built from the payload; missing or falsy metadata
            becomes an empty dict.
        """
        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
        """Flatten one ``key``/``value`` filter entry into field conditions.

        Args:
            key: Metadata key (possibly already a nested path).
            value: Value to match. Dicts recurse with ``key.subkey`` paths,
                lists recurse per element (dict elements use the ``key[]``
                path), anything else becomes an exact-match condition.

        Returns:
            List of ``FieldCondition`` objects whose keys are prefixed with
            ``self.metadata_payload_key``.
        """
        from qdrant_client.http import models as rest

        out = []

        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                out.extend(self._build_condition(f"{key}.{sub_key}", sub_value))
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    out.extend(self._build_condition(f"{key}[]", item))
                else:
                    out.extend(self._build_condition(f"{key}", item))
        else:
            out.append(
                rest.FieldCondition(
                    key=f"{self.metadata_payload_key}.{key}",
                    match=rest.MatchValue(value=value),
                )
            )

        return out

    def _qdrant_filter_from_dict(
        self, filter: Optional[DictFilter]
    ) -> Optional[rest.Filter]:
        """Convert a (deprecated) dict filter into a Qdrant ``Filter``.

        Args:
            filter: Mapping of metadata keys to values, or None.

        Returns:
            None when ``filter`` is None or empty; otherwise a ``Filter``
            whose ``must`` clause holds every condition produced by
            ``_build_condition`` for each entry.
        """
        from qdrant_client.http import models as rest

        if not filter:
            return None

        return rest.Filter(
            must=[
                condition
                for key, value in filter.items()
                for condition in self._build_condition(key, value)
            ]
        )