You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/vectorstores/qdrant.py

244 lines
8.0 KiB
Python

"""Wrapper around Qdrant vector database."""
import uuid
from operator import itemgetter
from typing import Any, Callable, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
class Qdrant(VectorStore):
    """Wrapper around Qdrant vector database.

    To use you should have the ``qdrant-client`` package installed.

    Example:
        .. code-block:: python

            from langchain import Qdrant

            client = QdrantClient()
            collection_name = "MyCollection"
            qdrant = Qdrant(client, collection_name, embedding_function)
    """

    # Payload keys under which each point stores its text and its metadata.
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"

    def __init__(self, client: Any, collection_name: str, embedding_function: Callable):
        """Initialize with necessary components.

        Args:
            client: A ``qdrant_client.QdrantClient`` instance.
            collection_name: Name of the collection used for upserts and searches.
            embedding_function: Callable applied to a single text (document or
                query) to obtain its vector.

        Raises:
            ValueError: If ``qdrant-client`` is not installed, or if ``client``
                is not a ``qdrant_client.QdrantClient``.
        """
        # Imported lazily so the module can be imported without qdrant-client.
        try:
            import qdrant_client
        except ImportError:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            )

        if not isinstance(client, qdrant_client.QdrantClient):
            raise ValueError(
                f"client should be an instance of qdrant_client.QdrantClient, "
                f"got {type(client)}"
            )

        self.client: qdrant_client.QdrantClient = client
        self.collection_name = collection_name
        self.embedding_function = embedding_function

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.

        Returns:
            List of ids from adding the texts into the vectorstore.

        Raises:
            ValueError: If any of the texts is None.
        """
        from qdrant_client.http import models as rest

        # ``texts`` may be a one-shot iterator (e.g. a generator) and is needed
        # three times below (ids, vectors, payloads), so materialize it once.
        texts = list(texts)
        ids = [uuid.uuid4().hex for _ in texts]
        self.client.upsert(
            collection_name=self.collection_name,
            points=rest.Batch(
                ids=ids,
                vectors=[self.embedding_function(text) for text in texts],
                payloads=self._build_payloads(texts, metadatas),
            ),
        )
        return ids

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        results = self.similarity_search_with_score(query, k)
        # Drop the scores, keeping only the Document of each (doc, score) pair.
        return list(map(itemgetter(0), results))

    def similarity_search_with_score(
        self, query: str, k: int = 4
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query and score for each
        """
        embedding = self.embedding_function(query)
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            with_payload=True,
            limit=k,
        )
        return [
            (
                self._document_from_scored_point(result),
                result.score,
            )
            for result in results
        ]

    def max_marginal_relevance_search(
        self, query: str, k: int = 4, fetch_k: int = 20
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function(query)
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            with_payload=True,
            with_vectors=True,
            # Fetch the larger candidate pool; MMR then picks ``k`` out of it.
            # Fetching only ``k`` would leave MMR nothing to diversify over.
            limit=fetch_k,
        )
        embeddings = [result.vector for result in results]
        mmr_selected = maximal_marginal_relevance(embedding, embeddings, k=k)
        return [self._document_from_scored_point(results[i]) for i in mmr_selected]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "Qdrant":
        """Construct Qdrant wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates a Qdrant client and (re)creates the target collection.
            3. Uploads the texts, their embeddings and metadatas.

        This is intended to be a quick way to get started.

        Args:
            texts: Texts to embed and store.
            embedding: Embeddings implementation; ``embed_documents`` is used
                for the texts and ``embed_query`` is handed to the wrapper.
            metadatas: Optional list of metadatas associated with the texts.
            **kwargs: ``host`` (looked up in kwargs or the ``QDRANT_HOST``
                environment variable), ``collection_name`` (defaults to a
                random hex id) and ``distance_func`` (defaults to ``"Cosine"``)
                are consumed here; everything else is forwarded to
                ``qdrant_client.QdrantClient``.

        Returns:
            A Qdrant wrapper bound to the populated collection.

        Raises:
            ValueError: If ``qdrant-client`` is not installed or any text is None.

        Example:
            .. code-block:: python

                from langchain import Qdrant
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                qdrant = Qdrant.from_texts(texts, embeddings)
        """
        try:
            import qdrant_client
        except ImportError:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            )

        from qdrant_client.http import models as rest

        # Just do a single quick embedding to get vector size
        partial_embeddings = embedding.embed_documents(texts[:1])
        vector_size = len(partial_embeddings[0])

        qdrant_host = get_from_dict_or_env(kwargs, "host", "QDRANT_HOST")
        # "host" is absent from kwargs when it was resolved from the
        # environment, so pop with a default instead of raising KeyError.
        kwargs.pop("host", None)
        collection_name = kwargs.pop("collection_name", uuid.uuid4().hex)
        # Upper-cased because it is used as a ``rest.Distance`` member name.
        distance_func = kwargs.pop("distance_func", "Cosine").upper()

        client = qdrant_client.QdrantClient(host=qdrant_host, **kwargs)

        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(
                size=vector_size,
                distance=rest.Distance[distance_func],
            ),
        )

        # Now generate the embeddings for all the texts
        embeddings = embedding.embed_documents(texts)

        client.upsert(
            collection_name=collection_name,
            points=rest.Batch(
                ids=[uuid.uuid4().hex for _ in texts],
                vectors=embeddings,
                payloads=cls._build_payloads(texts, metadatas),
            ),
        )

        return cls(client, collection_name, embedding.embed_query)

    @classmethod
    def _build_payloads(
        cls, texts: Iterable[str], metadatas: Optional[List[dict]]
    ) -> List[dict]:
        """Build one payload dict per text, pairing it with its metadata.

        Args:
            texts: Texts to store; none of them may be None.
            metadatas: Optional metadatas, indexed in parallel with ``texts``.

        Returns:
            A list of dicts keyed by ``CONTENT_KEY`` and ``METADATA_KEY``; the
            metadata value is None when ``metadatas`` is None.

        Raises:
            ValueError: If any text is None.
        """
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "At least one of the texts is None. Please remove it before "
                    "calling .from_texts or .add_texts on Qdrant instance."
                )
            payloads.append(
                {
                    cls.CONTENT_KEY: text,
                    cls.METADATA_KEY: metadatas[i] if metadatas is not None else None,
                }
            )

        return payloads

    @classmethod
    def _document_from_scored_point(cls, scored_point: Any) -> Document:
        """Convert a search result into a Document using its payload.

        A missing or empty metadata entry becomes an empty dict.
        """
        return Document(
            page_content=scored_point.payload.get(cls.CONTENT_KEY),
            metadata=scored_point.payload.get(cls.METADATA_KEY) or {},
        )