forked from Archives/langchain
Update qdrant interface (#3971)
Hello 1) Passing `embedding_function` as a callable seems to be outdated and the common interface is to pass `Embeddings` instance 2) At the moment `Qdrant.add_texts` is designed to be used with `embeddings.embed_query`, which is 1) slow 2) causes ambiguity due to 1. It should be used with `embeddings.embed_documents` This PR solves both problems and also provides some new tests
This commit is contained in:
parent
76ed41f48a
commit
2324f19c85
@ -2,10 +2,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
import warnings
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores import VectorStore
|
from langchain.vectorstores import VectorStore
|
||||||
@ -37,9 +40,10 @@ class Qdrant(VectorStore):
|
|||||||
self,
|
self,
|
||||||
client: Any,
|
client: Any,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
embedding_function: Callable,
|
embeddings: Optional[Embeddings] = None,
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: str = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: str = METADATA_KEY,
|
||||||
|
embedding_function: Optional[Callable] = None, # deprecated
|
||||||
):
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize with necessary components."""
|
||||||
try:
|
try:
|
||||||
@ -56,12 +60,85 @@ class Qdrant(VectorStore):
|
|||||||
f"got {type(client)}"
|
f"got {type(client)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if embeddings is None and embedding_function is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
if embeddings is not None and embedding_function is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Both `embeddings` and `embedding_function` are passed. "
|
||||||
|
"Use `embeddings` only."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self._embeddings_function = embedding_function
|
||||||
self.client: qdrant_client.QdrantClient = client
|
self.client: qdrant_client.QdrantClient = client
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.embedding_function = embedding_function
|
|
||||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||||
|
|
||||||
|
if embedding_function is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Using `embedding_function` is deprecated. "
|
||||||
|
"Pass `Embeddings` instance to `embeddings` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(embeddings, Embeddings):
|
||||||
|
warnings.warn(
|
||||||
|
"`embeddings` should be an instance of `Embeddings`."
|
||||||
|
"Using `embeddings` as `embedding_function` which is deprecated"
|
||||||
|
)
|
||||||
|
self._embeddings_function = embeddings
|
||||||
|
self.embeddings = None
|
||||||
|
|
||||||
|
def _embed_query(self, query: str) -> List[float]:
|
||||||
|
"""Embed query text.
|
||||||
|
|
||||||
|
Used to provide backward compatibility with `embedding_function` argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of floats representing the query embedding.
|
||||||
|
"""
|
||||||
|
if self.embeddings is not None:
|
||||||
|
embedding = self.embeddings.embed_query(query)
|
||||||
|
else:
|
||||||
|
if self._embeddings_function is not None:
|
||||||
|
embedding = self._embeddings_function(query)
|
||||||
|
else:
|
||||||
|
raise ValueError("Neither of embeddings or embedding_function is set")
|
||||||
|
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
|
||||||
|
|
||||||
|
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
|
||||||
|
"""Embed search texts.
|
||||||
|
|
||||||
|
Used to provide backward compatibility with `embedding_function` argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Iterable of texts to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of floats representing the texts embedding.
|
||||||
|
"""
|
||||||
|
if self.embeddings is not None:
|
||||||
|
embeddings = self.embeddings.embed_documents(list(texts))
|
||||||
|
if hasattr(embeddings, "tolist"):
|
||||||
|
embeddings = embeddings.tolist()
|
||||||
|
elif self._embeddings_function is not None:
|
||||||
|
embeddings = []
|
||||||
|
for text in texts:
|
||||||
|
embedding = self._embeddings_function(text)
|
||||||
|
if hasattr(embeddings, "tolist"):
|
||||||
|
embedding = embedding.tolist()
|
||||||
|
embeddings.append(embedding)
|
||||||
|
else:
|
||||||
|
raise ValueError("Neither of embeddings or embedding_function is set")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -79,12 +156,16 @@ class Qdrant(VectorStore):
|
|||||||
"""
|
"""
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
|
texts = list(
|
||||||
|
texts
|
||||||
|
) # otherwise iterable might be exhausted after id calculation
|
||||||
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
|
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
|
||||||
|
|
||||||
self.client.upsert(
|
self.client.upsert(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
points=rest.Batch.construct(
|
points=rest.Batch.construct(
|
||||||
ids=ids,
|
ids=ids,
|
||||||
vectors=[self.embedding_function(text) for text in texts],
|
vectors=self._embed_texts(texts),
|
||||||
payloads=self._build_payloads(
|
payloads=self._build_payloads(
|
||||||
texts,
|
texts,
|
||||||
metadatas,
|
metadatas,
|
||||||
@ -129,10 +210,10 @@ class Qdrant(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query and score for each.
|
List of Documents most similar to the query and score for each.
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function(query)
|
|
||||||
results = self.client.search(
|
results = self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query_vector=embedding,
|
query_vector=self._embed_query(query),
|
||||||
query_filter=self._qdrant_filter_from_dict(filter),
|
query_filter=self._qdrant_filter_from_dict(filter),
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
limit=k,
|
limit=k,
|
||||||
@ -172,7 +253,8 @@ class Qdrant(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function(query)
|
|
||||||
|
embedding = self._embed_query(query)
|
||||||
results = self.client.search(
|
results = self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query_vector=embedding,
|
query_vector=embedding,
|
||||||
@ -182,7 +264,7 @@ class Qdrant(VectorStore):
|
|||||||
)
|
)
|
||||||
embeddings = [result.vector for result in results]
|
embeddings = [result.vector for result in results]
|
||||||
mmr_selected = maximal_marginal_relevance(
|
mmr_selected = maximal_marginal_relevance(
|
||||||
embedding, embeddings, k=k, lambda_mult=lambda_mult
|
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
self._document_from_scored_point(
|
self._document_from_scored_point(
|
||||||
@ -337,7 +419,7 @@ class Qdrant(VectorStore):
|
|||||||
return cls(
|
return cls(
|
||||||
client=client,
|
client=client,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embedding_function=embedding.embed_query,
|
embeddings=embedding,
|
||||||
content_payload_key=content_payload_key,
|
content_payload_key=content_payload_key,
|
||||||
metadata_payload_key=metadata_payload_key,
|
metadata_payload_key=metadata_payload_key,
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
"""Test Qdrant functionality."""
|
"""Test Qdrant functionality."""
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores import Qdrant
|
from langchain.vectorstores import Qdrant
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
@ -29,6 +32,22 @@ def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None:
|
|||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_qdrant_add_documents() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:")
|
||||||
|
|
||||||
|
new_texts = ["foobar", "foobaz"]
|
||||||
|
docsearch.add_documents([Document(page_content=content) for content in new_texts])
|
||||||
|
output = docsearch.similarity_search("foobar", k=1)
|
||||||
|
# FakeEmbeddings return the same query embedding as the first document embedding
|
||||||
|
# computed in `embedding.embed_documents`. Since embed_documents is called twice,
|
||||||
|
# "foo" embedding is the same as "foobar" embedding
|
||||||
|
assert output == [Document(page_content="foobar")] or output == [
|
||||||
|
Document(page_content="foo")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["content_payload_key", "metadata_payload_key"],
|
["content_payload_key", "metadata_payload_key"],
|
||||||
[
|
[
|
||||||
@ -98,3 +117,51 @@ def test_qdrant_max_marginal_relevance_search(
|
|||||||
Document(page_content="foo", metadata={"page": 0}),
|
Document(page_content="foo", metadata={"page": 0}),
|
||||||
Document(page_content="bar", metadata={"page": 1}),
|
Document(page_content="bar", metadata={"page": 1}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["embeddings", "embedding_function"],
|
||||||
|
[
|
||||||
|
(FakeEmbeddings(), None),
|
||||||
|
(FakeEmbeddings().embed_query, None),
|
||||||
|
(None, FakeEmbeddings().embed_query),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_qdrant_embedding_interface(
|
||||||
|
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
|
||||||
|
) -> None:
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
client = QdrantClient(":memory:")
|
||||||
|
collection_name = "test"
|
||||||
|
|
||||||
|
Qdrant(
|
||||||
|
client,
|
||||||
|
collection_name,
|
||||||
|
embeddings=embeddings,
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["embeddings", "embedding_function"],
|
||||||
|
[
|
||||||
|
(FakeEmbeddings(), FakeEmbeddings().embed_query),
|
||||||
|
(None, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_qdrant_embedding_interface_raises(
|
||||||
|
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
|
||||||
|
) -> None:
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
client = QdrantClient(":memory:")
|
||||||
|
collection_name = "test"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Qdrant(
|
||||||
|
client,
|
||||||
|
collection_name,
|
||||||
|
embeddings=embeddings,
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user