diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 33447561..168da43a 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -2,10 +2,13 @@ from __future__ import annotations import uuid +import warnings from hashlib import md5 from operator import itemgetter from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +import numpy as np + from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore @@ -37,9 +40,10 @@ class Qdrant(VectorStore): self, client: Any, collection_name: str, - embedding_function: Callable, + embeddings: Optional[Embeddings] = None, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, + embedding_function: Optional[Callable] = None, # deprecated ): """Initialize with necessary components.""" try: @@ -56,12 +60,85 @@ class Qdrant(VectorStore): f"got {type(client)}" ) + if embeddings is None and embedding_function is None: + raise ValueError( + "`embeddings` value can't be None. Pass `Embeddings` instance." + ) + + if embeddings is not None and embedding_function is not None: + raise ValueError( + "Both `embeddings` and `embedding_function` are passed. " + "Use `embeddings` only." + ) + + self.embeddings = embeddings + self._embeddings_function = embedding_function self.client: qdrant_client.QdrantClient = client self.collection_name = collection_name - self.embedding_function = embedding_function self.content_payload_key = content_payload_key or self.CONTENT_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY + if embedding_function is not None: + warnings.warn( + "Using `embedding_function` is deprecated. " + "Pass `Embeddings` instance to `embeddings` instead." + ) + + if not isinstance(embeddings, Embeddings): + warnings.warn( + "`embeddings` should be an instance of `Embeddings`." + "Using `embeddings` as `embedding_function` which is deprecated" + ) + self._embeddings_function = embeddings + self.embeddings = None + + def _embed_query(self, query: str) -> List[float]: + """Embed query text. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + query: Query text. + + Returns: + List of floats representing the query embedding. + """ + if self.embeddings is not None: + embedding = self.embeddings.embed_query(query) + else: + if self._embeddings_function is not None: + embedding = self._embeddings_function(query) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + return embedding.tolist() if hasattr(embedding, "tolist") else embedding + + def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: + """Embed search texts. + + Used to provide backward compatibility with `embedding_function` argument. + + Args: + texts: Iterable of texts to embed. + + Returns: + List of floats representing the texts embedding. + """ + if self.embeddings is not None: + embeddings = self.embeddings.embed_documents(list(texts)) + if hasattr(embeddings, "tolist"): + embeddings = embeddings.tolist() + elif self._embeddings_function is not None: + embeddings = [] + for text in texts: + embedding = self._embeddings_function(text) + if hasattr(embeddings, "tolist"): + embedding = embedding.tolist() + embeddings.append(embedding) + else: + raise ValueError("Neither of embeddings or embedding_function is set") + + return embeddings + def add_texts( self, texts: Iterable[str], @@ -79,12 +156,16 @@ class Qdrant(VectorStore): """ from qdrant_client.http import models as rest + texts = list( + texts + ) # otherwise iterable might be exhausted after id calculation ids = [md5(text.encode("utf-8")).hexdigest() for text in texts] + self.client.upsert( collection_name=self.collection_name, points=rest.Batch.construct( ids=ids, - vectors=[self.embedding_function(text) for text in texts], + vectors=self._embed_texts(texts), payloads=self._build_payloads( texts, metadatas, @@ -129,10 +210,10 @@ class Qdrant(VectorStore): Returns: List of Documents most similar to the query and score for each. """ - embedding = self.embedding_function(query) + results = self.client.search( collection_name=self.collection_name, - query_vector=embedding, + query_vector=self._embed_query(query), query_filter=self._qdrant_filter_from_dict(filter), with_payload=True, limit=k, @@ -172,7 +253,8 @@ class Qdrant(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - embedding = self.embedding_function(query) + + embedding = self._embed_query(query) results = self.client.search( collection_name=self.collection_name, query_vector=embedding, @@ -182,7 +264,7 @@ class Qdrant(VectorStore): ) embeddings = [result.vector for result in results] mmr_selected = maximal_marginal_relevance( - embedding, embeddings, k=k, lambda_mult=lambda_mult + np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult ) return [ self._document_from_scored_point( @@ -337,7 +419,7 @@ class Qdrant(VectorStore): return cls( client=client, collection_name=collection_name, - embedding_function=embedding.embed_query, + embeddings=embedding, content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, ) diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index 5f2832c3..1f43a0bc 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -1,7 +1,10 @@ """Test Qdrant functionality.""" +from typing import Callable, Optional + import pytest from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings @@ -29,6 +32,22 @@ def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None: assert output == [Document(page_content="foo")] +def test_qdrant_add_documents() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:") + + new_texts = ["foobar", "foobaz"] + docsearch.add_documents([Document(page_content=content) for content in new_texts]) + output = docsearch.similarity_search("foobar", k=1) + # FakeEmbeddings return the same query embedding as the first document embedding + # computed in `embedding.embed_documents`. Since embed_documents is called twice, + # "foo" embedding is the same as "foobar" embedding + assert output == [Document(page_content="foobar")] or output == [ + Document(page_content="foo") + ] + + @pytest.mark.parametrize( ["content_payload_key", "metadata_payload_key"], [ @@ -98,3 +117,51 @@ def test_qdrant_max_marginal_relevance_search( Document(page_content="foo", metadata={"page": 0}), Document(page_content="bar", metadata={"page": 1}), ] + + +@pytest.mark.parametrize( + ["embeddings", "embedding_function"], + [ + (FakeEmbeddings(), None), + (FakeEmbeddings().embed_query, None), + (None, FakeEmbeddings().embed_query), + ], +) +def test_qdrant_embedding_interface( + embeddings: Optional[Embeddings], embedding_function: Optional[Callable] +) -> None: + from qdrant_client import QdrantClient + + client = QdrantClient(":memory:") + collection_name = "test" + + Qdrant( + client, + collection_name, + embeddings=embeddings, + embedding_function=embedding_function, + ) + + +@pytest.mark.parametrize( + ["embeddings", "embedding_function"], + [ + (FakeEmbeddings(), FakeEmbeddings().embed_query), + (None, None), + ], +) +def test_qdrant_embedding_interface_raises( + embeddings: Optional[Embeddings], embedding_function: Optional[Callable] +) -> None: + from qdrant_client import QdrantClient + + client = QdrantClient(":memory:") + collection_name = "test" + + with pytest.raises(ValueError): + Qdrant( + client, + collection_name, + embeddings=embeddings, + embedding_function=embedding_function, + )