Accept uuids kwargs for weaviate (#4800)

# Accept uuids kwargs for weaviate

Fixes #4791
This commit is contained in:
yujiosaka 2023-05-17 07:26:46 +09:00 committed by GitHub
parent e78c9be312
commit 6561efebb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 89 additions and 5 deletions

View File

@ -47,7 +47,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
arbitrary_types_allowed = True arbitrary_types_allowed = True
# added text_key # added text_key
def add_documents(self, docs: List[Document]) -> List[str]: def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
"""Upload documents to Weaviate.""" """Upload documents to Weaviate."""
from weaviate.util import get_valid_uuid from weaviate.util import get_valid_uuid
@ -56,7 +56,14 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
metadata = doc.metadata or {} metadata = doc.metadata or {}
data_properties = {self._text_key: doc.page_content, **metadata} data_properties = {self._text_key: doc.page_content, **metadata}
_id = get_valid_uuid(uuid4())
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
batch.add_data_object(data_properties, self._index_name, _id) batch.add_data_object(data_properties, self._index_name, _id)
ids.append(_id) ids.append(_id)
return ids return ids

View File

@ -351,7 +351,7 @@ class Redis(VectorStore):
if self.relevance_score_fn is None: if self.relevance_score_fn is None:
raise ValueError( raise ValueError(
"relevance_score_fn must be provided to" "relevance_score_fn must be provided to"
" Weaviate constructor to normalize scores" " Redis constructor to normalize scores"
) )
docs_and_scores = self.similarity_search_with_score(query, k=k) docs_and_scores = self.similarity_search_with_score(query, k=k)
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores] return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]

View File

@ -134,7 +134,12 @@ class Weaviate(VectorStore):
for key in metadatas[i].keys(): for key in metadatas[i].keys():
data_properties[key] = json_serializable(metadatas[i][key]) data_properties[key] = json_serializable(metadatas[i][key])
_id = get_valid_uuid(uuid4()) # If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
if self._embedding is not None: if self._embedding is not None:
embeddings = self._embedding.embed_documents(list(doc)) embeddings = self._embedding.embed_documents(list(doc))
@ -385,7 +390,12 @@ class Weaviate(VectorStore):
for key in metadatas[i].keys(): for key in metadatas[i].keys():
data_properties[key] = metadatas[i][key] data_properties[key] = metadatas[i][key]
_id = get_valid_uuid(uuid4()) # If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
# if an embedding strategy is not provided, we let # if an embedding strategy is not provided, we let
# weaviate create the embedding. Note that this will only # weaviate create the embedding. Note that this will only

View File

@ -1,6 +1,7 @@
"""Test Weaviate functionality.""" """Test Weaviate functionality."""
import logging import logging
import os import os
import uuid
from typing import Generator, Union from typing import Generator, Union
from uuid import uuid4 from uuid import uuid4
@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever:
assert output == [ assert output == [
Document(page_content="foo", metadata={"page": 0}), Document(page_content="foo", metadata={"page": 0}),
] ]
@pytest.mark.vcr(ignore_localhost=True)
def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None:
"""Test that documents added with the same UUID replace one another."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
# Weaviate replaces the object if the UUID already exists
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
client = Client(weaviate_url)
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=f"LangChain_{uuid4().hex}",
text_key="text",
attributes=["page"],
)
for i, text in enumerate(texts):
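# Every document is added under the same UUID, so each call
# overwrites the object stored by the previous one.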
retriever.add_documents(
[Document(page_content=text, metadata=metadatas[i])], uuids=[uuids[i]]
)
output = retriever.get_relevant_documents("foo")
assert len(output) == 1

View File

@ -1,6 +1,7 @@
"""Test Weaviate functionality.""" """Test Weaviate functionality."""
import logging import logging
import os import os
import uuid
from typing import Generator, Union from typing import Generator, Union
import pytest import pytest
@ -80,6 +81,26 @@ class TestWeaviate:
) )
assert output == [Document(page_content="foo", metadata={"page": 0})] assert output == [Document(page_content="foo", metadata={"page": 0})]
@pytest.mark.vcr(ignore_localhost=True)
def test_similarity_search_with_uuids(
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
) -> None:
"""Test end to end construction and search with uuids."""
texts = ["foo", "bar", "baz"]
# Weaviate replaces the object if the UUID already exists
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Weaviate.from_texts(
texts,
embedding_openai,
metadatas=metadatas,
weaviate_url=weaviate_url,
uuids=uuids,
)
output = docsearch.similarity_search("foo", k=2)
assert len(output) == 1
@pytest.mark.vcr(ignore_localhost=True) @pytest.mark.vcr(ignore_localhost=True)
def test_max_marginal_relevance_search( def test_max_marginal_relevance_search(
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
@ -181,3 +202,23 @@ class TestWeaviate:
Document(page_content="foo"), Document(page_content="foo"),
Document(page_content="foo"), Document(page_content="foo"),
] ]
def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None:
texts = ["foo", "bar", "baz"]
embedding = FakeEmbeddings()
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, text) for text in texts]
docsearch = Weaviate.from_texts(
texts,
embedding=embedding,
weaviate_url=weaviate_url,
uuids=uuids,
)
# Weaviate replaces the object if the UUID already exists
docsearch.add_texts(["foo"], uuids=[uuids[0]])
output = docsearch.similarity_search_by_vector(
embedding.embed_query("foo"), k=2
)
assert output[0] == Document(page_content="foo")
assert output[1] != Document(page_content="foo")