diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py index 8d8da48b..2f743ced 100644 --- a/langchain/retrievers/weaviate_hybrid_search.py +++ b/langchain/retrievers/weaviate_hybrid_search.py @@ -47,7 +47,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever): arbitrary_types_allowed = True # added text_key - def add_documents(self, docs: List[Document]) -> List[str]: + def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]: """Upload documents to Weaviate.""" from weaviate.util import get_valid_uuid @@ -56,7 +56,14 @@ class WeaviateHybridSearchRetriever(BaseRetriever): for i, doc in enumerate(docs): metadata = doc.metadata or {} data_properties = {self._text_key: doc.page_content, **metadata} - _id = get_valid_uuid(uuid4()) + + # If the UUID of one of the objects already exists + # then the existing object will be replaced by the new object. + if "uuids" in kwargs: + _id = kwargs["uuids"][i] + else: + _id = get_valid_uuid(uuid4()) + batch.add_data_object(data_properties, self._index_name, _id) ids.append(_id) return ids diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 0adec6a4..d8c2d05a 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -351,7 +351,7 @@ class Redis(VectorStore): if self.relevance_score_fn is None: raise ValueError( "relevance_score_fn must be provided to" - " Weaviate constructor to normalize scores" + " Redis constructor to normalize scores" ) docs_and_scores = self.similarity_search_with_score(query, k=k) return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores] diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index 78d8beec..34ed923d 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -134,7 +134,12 @@ class Weaviate(VectorStore): for key in metadatas[i].keys(): data_properties[key] = 
json_serializable(metadatas[i][key]) - _id = get_valid_uuid(uuid4()) + # If the UUID of one of the objects already exists + # then the existing object will be replaced by the new object. + if "uuids" in kwargs: + _id = kwargs["uuids"][i] + else: + _id = get_valid_uuid(uuid4()) if self._embedding is not None: embeddings = self._embedding.embed_documents(list(doc)) @@ -385,7 +390,12 @@ class Weaviate(VectorStore): for key in metadatas[i].keys(): data_properties[key] = metadatas[i][key] - _id = get_valid_uuid(uuid4()) + # If the UUID of one of the objects already exists + # then the existing object will be replaced by the new object. + if "uuids" in kwargs: + _id = kwargs["uuids"][i] + else: + _id = get_valid_uuid(uuid4()) # if an embedding strategy is not provided, we let # weaviate create the embedding. Note that this will only diff --git a/tests/integration_tests/retrievers/test_weaviate_hybrid_search.py b/tests/integration_tests/retrievers/test_weaviate_hybrid_search.py index a5013c42..fabdbb97 100644 --- a/tests/integration_tests/retrievers/test_weaviate_hybrid_search.py +++ b/tests/integration_tests/retrievers/test_weaviate_hybrid_search.py @@ -1,6 +1,7 @@ """Test Weaviate functionality.""" import logging import os +import uuid from typing import Generator, Union from uuid import uuid4 @@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever: assert output == [ Document(page_content="foo", metadata={"page": 0}), ] + + @pytest.mark.vcr(ignore_localhost=True) + def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None: + """Test that documents added under the same UUID replace each other.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + # Weaviate replaces the object if the UUID already exists + uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts] + + client = Client(weaviate_url) + + retriever = WeaviateHybridSearchRetriever( + client=client, + index_name=f"LangChain_{uuid4().hex}", + text_key="text", 
attributes=["page"], + ) + for i, text in enumerate(texts): + # each document is added under the same UUID + retriever.add_documents( + [Document(page_content=text, metadata=metadatas[i])], uuids=[uuids[i]] + ) + + output = retriever.get_relevant_documents("foo") + assert len(output) == 1 diff --git a/tests/integration_tests/vectorstores/test_weaviate.py b/tests/integration_tests/vectorstores/test_weaviate.py index 7170bf64..127695ec 100644 --- a/tests/integration_tests/vectorstores/test_weaviate.py +++ b/tests/integration_tests/vectorstores/test_weaviate.py @@ -1,6 +1,7 @@ """Test Weaviate functionality.""" import logging import os +import uuid from typing import Generator, Union import pytest @@ -80,6 +81,26 @@ class TestWeaviate: ) assert output == [Document(page_content="foo", metadata={"page": 0})] + @pytest.mark.vcr(ignore_localhost=True) + def test_similarity_search_with_uuids( + self, weaviate_url: str, embedding_openai: OpenAIEmbeddings + ) -> None: + """Test end to end construction and search with uuids.""" + texts = ["foo", "bar", "baz"] + # Weaviate replaces the object if the UUID already exists + uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts] + + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = Weaviate.from_texts( + texts, + embedding_openai, + metadatas=metadatas, + weaviate_url=weaviate_url, + uuids=uuids, + ) + output = docsearch.similarity_search("foo", k=2) + assert len(output) == 1 + @pytest.mark.vcr(ignore_localhost=True) def test_max_marginal_relevance_search( self, weaviate_url: str, embedding_openai: OpenAIEmbeddings @@ -181,3 +202,23 @@ class TestWeaviate: Document(page_content="foo"), Document(page_content="foo"), ] + + def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None: + texts = ["foo", "bar", "baz"] + embedding = FakeEmbeddings() + uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, text) for text in texts] + + docsearch = Weaviate.from_texts( + texts, + embedding=embedding, + weaviate_url=weaviate_url, + uuids=uuids, + ) + 
# Weaviate replaces the object if the UUID already exists + docsearch.add_texts(["foo"], uuids=[uuids[0]]) + output = docsearch.similarity_search_by_vector( + embedding.embed_query("foo"), k=2 + ) + assert output[0] == Document(page_content="foo") + assert output[1] != Document(page_content="foo")