Update qdrant interface (#3971)

Hello

1) Passing `embedding_function` as a callable seems to be outdated and
the common interface is to pass `Embeddings` instance

2) At the moment `Qdrant.add_texts` is designed to be used with
`embeddings.embed_query`, which is 1) slow 2) causes ambiguity due to 1.
It should be used with `embeddings.embed_documents`

This PR solves both problems and also provides some new tests
This commit is contained in:
George 2023-05-06 03:46:40 +04:00 committed by GitHub
parent 76ed41f48a
commit 2324f19c85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 157 additions and 8 deletions

View File

@ -2,10 +2,13 @@
from __future__ import annotations
import uuid
import warnings
from hashlib import md5
from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
@ -37,9 +40,10 @@ class Qdrant(VectorStore):
self,
client: Any,
collection_name: str,
embedding_function: Callable,
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
embedding_function: Optional[Callable] = None, # deprecated
):
"""Initialize with necessary components."""
try:
@ -56,12 +60,85 @@ class Qdrant(VectorStore):
f"got {type(client)}"
)
if embeddings is None and embedding_function is None:
raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance."
)
if embeddings is not None and embedding_function is not None:
raise ValueError(
"Both `embeddings` and `embedding_function` are passed. "
"Use `embeddings` only."
)
self.embeddings = embeddings
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.collection_name = collection_name
self.embedding_function = embedding_function
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
if embedding_function is not None:
warnings.warn(
"Using `embedding_function` is deprecated. "
"Pass `Embeddings` instance to `embeddings` instead."
)
if not isinstance(embeddings, Embeddings):
warnings.warn(
"`embeddings` should be an instance of `Embeddings`."
"Using `embeddings` as `embedding_function` which is deprecated"
)
self._embeddings_function = embeddings
self.embeddings = None
def _embed_query(self, query: str) -> List[float]:
"""Embed query text.
Used to provide backward compatibility with `embedding_function` argument.
Args:
query: Query text.
Returns:
List of floats representing the query embedding.
"""
if self.embeddings is not None:
embedding = self.embeddings.embed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
"""Embed search texts.
Used to provide backward compatibility with `embedding_function` argument.
Args:
texts: Iterable of texts to embed.
Returns:
List of floats representing the texts embedding.
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embeddings, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings
def add_texts(
self,
texts: Iterable[str],
@ -79,12 +156,16 @@ class Qdrant(VectorStore):
"""
from qdrant_client.http import models as rest
texts = list(
texts
) # otherwise iterable might be exhausted after id calculation
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
self.client.upsert(
collection_name=self.collection_name,
points=rest.Batch.construct(
ids=ids,
vectors=[self.embedding_function(text) for text in texts],
vectors=self._embed_texts(texts),
payloads=self._build_payloads(
texts,
metadatas,
@ -129,10 +210,10 @@ class Qdrant(VectorStore):
Returns:
List of Documents most similar to the query and score for each.
"""
embedding = self.embedding_function(query)
results = self.client.search(
collection_name=self.collection_name,
query_vector=embedding,
query_vector=self._embed_query(query),
query_filter=self._qdrant_filter_from_dict(filter),
with_payload=True,
limit=k,
@ -172,7 +253,8 @@ class Qdrant(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self.embedding_function(query)
embedding = self._embed_query(query)
results = self.client.search(
collection_name=self.collection_name,
query_vector=embedding,
@ -182,7 +264,7 @@ class Qdrant(VectorStore):
)
embeddings = [result.vector for result in results]
mmr_selected = maximal_marginal_relevance(
embedding, embeddings, k=k, lambda_mult=lambda_mult
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
self._document_from_scored_point(
@ -337,7 +419,7 @@ class Qdrant(VectorStore):
return cls(
client=client,
collection_name=collection_name,
embedding_function=embedding.embed_query,
embeddings=embedding,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)

View File

@ -1,7 +1,10 @@
"""Test Qdrant functionality."""
from typing import Callable, Optional
import pytest
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@ -29,6 +32,22 @@ def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None:
assert output == [Document(page_content="foo")]
def test_qdrant_add_documents() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:")
new_texts = ["foobar", "foobaz"]
docsearch.add_documents([Document(page_content=content) for content in new_texts])
output = docsearch.similarity_search("foobar", k=1)
# FakeEmbeddings return the same query embedding as the first document embedding
# computed in `embedding.embed_documents`. Since embed_documents is called twice,
# "foo" embedding is the same as "foobar" embedding
assert output == [Document(page_content="foobar")] or output == [
Document(page_content="foo")
]
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
@ -98,3 +117,51 @@ def test_qdrant_max_marginal_relevance_search(
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
]
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), None),
(FakeEmbeddings().embed_query, None),
(None, FakeEmbeddings().embed_query),
],
)
def test_qdrant_embedding_interface(
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
) -> None:
from qdrant_client import QdrantClient
client = QdrantClient(":memory:")
collection_name = "test"
Qdrant(
client,
collection_name,
embeddings=embeddings,
embedding_function=embedding_function,
)
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), FakeEmbeddings().embed_query),
(None, None),
],
)
def test_qdrant_embedding_interface_raises(
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
) -> None:
from qdrant_client import QdrantClient
client = QdrantClient(":memory:")
collection_name = "test"
with pytest.raises(ValueError):
Qdrant(
client,
collection_name,
embeddings=embeddings,
embedding_function=embedding_function,
)