@ -2,10 +2,13 @@
from __future__ import annotations
import uuid
import warnings
from hashlib import md5
from operator import itemgetter
from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Type , Union
import numpy as np
from langchain . docstore . document import Document
from langchain . embeddings . base import Embeddings
from langchain . vectorstores import VectorStore
@ -37,9 +40,10 @@ class Qdrant(VectorStore):
self ,
client : Any ,
collection_name : str ,
embedding _function: Callabl e,
embedding s: Optional [ Embeddings ] = Non e,
content_payload_key : str = CONTENT_KEY ,
metadata_payload_key : str = METADATA_KEY ,
embedding_function : Optional [ Callable ] = None , # deprecated
) :
""" Initialize with necessary components. """
try :
@ -56,12 +60,85 @@ class Qdrant(VectorStore):
f " got { type ( client ) } "
)
if embeddings is None and embedding_function is None :
raise ValueError (
" `embeddings` value can ' t be None. Pass `Embeddings` instance. "
)
if embeddings is not None and embedding_function is not None :
raise ValueError (
" Both `embeddings` and `embedding_function` are passed. "
" Use `embeddings` only. "
)
self . embeddings = embeddings
self . _embeddings_function = embedding_function
self . client : qdrant_client . QdrantClient = client
self . collection_name = collection_name
self . embedding_function = embedding_function
self . content_payload_key = content_payload_key or self . CONTENT_KEY
self . metadata_payload_key = metadata_payload_key or self . METADATA_KEY
if embedding_function is not None :
warnings . warn (
" Using `embedding_function` is deprecated. "
" Pass `Embeddings` instance to `embeddings` instead. "
)
if not isinstance ( embeddings , Embeddings ) :
warnings . warn (
" `embeddings` should be an instance of `Embeddings`. "
" Using `embeddings` as `embedding_function` which is deprecated "
)
self . _embeddings_function = embeddings
self . embeddings = None
def _embed_query(self, query: str) -> List[float]:
    """Turn a query string into its embedding vector.

    Exists so that both the `Embeddings` instance and the deprecated
    `embedding_function` callable can be used to embed a query.

    Args:
        query: Query text.

    Returns:
        The embedding produced for the query. When the produced value
        exposes a ``tolist`` method, the result of calling it is returned;
        otherwise the value is returned unchanged.

    Raises:
        ValueError: If neither an `Embeddings` instance nor a legacy
            embedding callable is set on the store.
    """
    embedder = self.embeddings
    if embedder is not None:
        vector = embedder.embed_query(query)
    else:
        legacy_fn = self._embeddings_function
        if legacy_fn is None:
            raise ValueError("Neither of embeddings or embedding_function is set")
        vector = legacy_fn(query)

    # Array-like results are converted through their own ``tolist``.
    if hasattr(vector, "tolist"):
        return vector.tolist()
    return vector
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
    """Embed a batch of texts.

    Exists so that both the `Embeddings` instance and the deprecated
    `embedding_function` callable can be used to embed documents.

    Args:
        texts: Iterable of texts to embed.

    Returns:
        One embedding per input text, in input order. With an `Embeddings`
        instance the batch result of ``embed_documents`` is returned
        (converted through ``tolist`` when the batch object exposes it).
        With the legacy callable each text is embedded individually and
        each vector exposing ``tolist`` is converted through it.

    Raises:
        ValueError: If neither an `Embeddings` instance nor a legacy
            embedding callable is set on the store.
    """
    if self.embeddings is not None:
        # Materialize the iterable: embed_documents receives a list.
        batch = self.embeddings.embed_documents(list(texts))
        if hasattr(batch, "tolist"):
            batch = batch.tolist()
        return batch

    if self._embeddings_function is None:
        raise ValueError("Neither of embeddings or embedding_function is set")

    vectors = []
    for text in texts:
        vector = self._embeddings_function(text)
        # Inspect the individual vector, not the accumulator list: a list
        # never has ``tolist``, so checking the accumulator would leave
        # array-like vectors unconverted.
        if hasattr(vector, "tolist"):
            vector = vector.tolist()
        vectors.append(vector)
    return vectors
def add_texts (
self ,
texts : Iterable [ str ] ,
@ -79,12 +156,16 @@ class Qdrant(VectorStore):
"""
from qdrant_client . http import models as rest
texts = list (
texts
) # otherwise iterable might be exhausted after id calculation
ids = [ md5 ( text . encode ( " utf-8 " ) ) . hexdigest ( ) for text in texts ]
self . client . upsert (
collection_name = self . collection_name ,
points = rest . Batch . construct (
ids = ids ,
vectors = [ self . embedding_function ( text ) for text in texts ] ,
vectors = self . _embed_texts ( texts ) ,
payloads = self . _build_payloads (
texts ,
metadatas ,
@ -129,10 +210,10 @@ class Qdrant(VectorStore):
Returns :
List of Documents most similar to the query and score for each .
"""
embedding = self . embedding_function ( query )
results = self . client . search (
collection_name = self . collection_name ,
query_vector = embedding ,
query_vector = self . _embed_query ( query ) ,
query_filter = self . _qdrant_filter_from_dict ( filter ) ,
with_payload = True ,
limit = k ,
@ -172,7 +253,8 @@ class Qdrant(VectorStore):
Returns :
List of Documents selected by maximal marginal relevance .
"""
embedding = self . embedding_function ( query )
embedding = self . _embed_query ( query )
results = self . client . search (
collection_name = self . collection_name ,
query_vector = embedding ,
@ -182,7 +264,7 @@ class Qdrant(VectorStore):
)
embeddings = [ result . vector for result in results ]
mmr_selected = maximal_marginal_relevance (
embedding, embeddings , k = k , lambda_mult = lambda_mult
np. array ( embedding) , embeddings , k = k , lambda_mult = lambda_mult
)
return [
self . _document_from_scored_point (
@ -337,7 +419,7 @@ class Qdrant(VectorStore):
return cls (
client = client ,
collection_name = collection_name ,
embedding _function= embedding . embed_query ,
embedding s= embedding ,
content_payload_key = content_payload_key ,
metadata_payload_key = metadata_payload_key ,
)