Accept uuids kwargs for weaviate (#4800)

# Accept uuids kwargs for weaviate

Fixes #4791
This commit is contained in:
yujiosaka 2023-05-17 07:26:46 +09:00 committed by GitHub
parent e78c9be312
commit 6561efebb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 89 additions and 5 deletions

View File

@ -47,7 +47,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
arbitrary_types_allowed = True
# added text_key
def add_documents(self, docs: List[Document]) -> List[str]:
def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
"""Upload documents to Weaviate."""
from weaviate.util import get_valid_uuid
@ -56,7 +56,14 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
for i, doc in enumerate(docs):
metadata = doc.metadata or {}
data_properties = {self._text_key: doc.page_content, **metadata}
_id = get_valid_uuid(uuid4())
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
batch.add_data_object(data_properties, self._index_name, _id)
ids.append(_id)
return ids

View File

@ -351,7 +351,7 @@ class Redis(VectorStore):
if self.relevance_score_fn is None:
raise ValueError(
"relevance_score_fn must be provided to"
" Weaviate constructor to normalize scores"
" Redis constructor to normalize scores"
)
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]

View File

@ -134,7 +134,12 @@ class Weaviate(VectorStore):
for key in metadatas[i].keys():
data_properties[key] = json_serializable(metadatas[i][key])
_id = get_valid_uuid(uuid4())
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
if self._embedding is not None:
embeddings = self._embedding.embed_documents(list(doc))
@ -385,7 +390,12 @@ class Weaviate(VectorStore):
for key in metadatas[i].keys():
data_properties[key] = metadatas[i][key]
_id = get_valid_uuid(uuid4())
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
# if an embedding strategy is not provided, we let
# weaviate create the embedding. Note that this will only

View File

@ -1,6 +1,7 @@
"""Test Weaviate functionality."""
import logging
import os
import uuid
from typing import Generator, Union
from uuid import uuid4
@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever:
assert output == [
Document(page_content="foo", metadata={"page": 0}),
]
@pytest.mark.vcr(ignore_localhost=True)
def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None:
    """Test that documents uploaded under an existing UUID replace each other.

    Every document is added with the same deterministic UUID, so after the
    three uploads only a single object is expected to be retrievable.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    # Weaviate replaces the object if the UUID already exists
    uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for _ in texts]

    client = Client(weaviate_url)
    retriever = WeaviateHybridSearchRetriever(
        client=client,
        index_name=f"LangChain_{uuid4().hex}",
        text_key="text",
        attributes=["page"],
    )
    # Upload one document per call so that each later call overwrites the
    # object stored by the previous one.
    for text, metadata, doc_uuid in zip(texts, metadatas, uuids):
        retriever.add_documents(
            [Document(page_content=text, metadata=metadata)], uuids=[doc_uuid]
        )

    output = retriever.get_relevant_documents("foo")
    assert len(output) == 1

View File

@ -1,6 +1,7 @@
"""Test Weaviate functionality."""
import logging
import os
import uuid
from typing import Generator, Union
import pytest
@ -80,6 +81,26 @@ class TestWeaviate:
)
assert output == [Document(page_content="foo", metadata={"page": 0})]
@pytest.mark.vcr(ignore_localhost=True)
def test_similarity_search_with_uuids(
    self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
) -> None:
    """Check that texts sharing one UUID collapse into a single stored object."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    # All texts get the same UUID; Weaviate replaces the object if the UUID
    # already exists, so only one object should survive.
    shared_uuid_name = "same-name"
    uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, shared_uuid_name) for _ in texts]

    store = Weaviate.from_texts(
        texts,
        embedding_openai,
        metadatas=metadatas,
        weaviate_url=weaviate_url,
        uuids=uuids,
    )

    hits = store.similarity_search("foo", k=2)
    assert len(hits) == 1
@pytest.mark.vcr(ignore_localhost=True)
def test_max_marginal_relevance_search(
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
@ -181,3 +202,23 @@ class TestWeaviate:
Document(page_content="foo"),
Document(page_content="foo"),
]
def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None:
    """Check that re-adding a text under its existing UUID does not duplicate it."""
    texts = ["foo", "bar", "baz"]
    fake_embedding = FakeEmbeddings()
    # Derive one deterministic UUID per text so "foo" can be re-added later
    # under the identifier it was first stored with.
    uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, name) for name in texts]
    store = Weaviate.from_texts(
        texts,
        embedding=fake_embedding,
        weaviate_url=weaviate_url,
        uuids=uuids,
    )

    # Weaviate replaces the object if the UUID already exists
    store.add_texts(["foo"], uuids=[uuids[0]])

    query_vector = fake_embedding.embed_query("foo")
    hits = store.similarity_search_by_vector(query_vector, k=2)
    expected = Document(page_content="foo")
    assert hits[0] == expected
    assert hits[1] != expected