forked from Archives/langchain
Accept uuids kwargs for weaviate (#4800)
# Accept uuids kwargs for weaviate Fixes #4791
This commit is contained in:
parent
e78c9be312
commit
6561efebb7
@ -47,7 +47,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
# added text_key
|
# added text_key
|
||||||
def add_documents(self, docs: List[Document]) -> List[str]:
|
def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
|
||||||
"""Upload documents to Weaviate."""
|
"""Upload documents to Weaviate."""
|
||||||
from weaviate.util import get_valid_uuid
|
from weaviate.util import get_valid_uuid
|
||||||
|
|
||||||
@ -56,7 +56,14 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
metadata = doc.metadata or {}
|
metadata = doc.metadata or {}
|
||||||
data_properties = {self._text_key: doc.page_content, **metadata}
|
data_properties = {self._text_key: doc.page_content, **metadata}
|
||||||
_id = get_valid_uuid(uuid4())
|
|
||||||
|
# If the UUID of one of the objects already exists
|
||||||
|
# then the existing objectwill be replaced by the new object.
|
||||||
|
if "uuids" in kwargs:
|
||||||
|
_id = kwargs["uuids"][i]
|
||||||
|
else:
|
||||||
|
_id = get_valid_uuid(uuid4())
|
||||||
|
|
||||||
batch.add_data_object(data_properties, self._index_name, _id)
|
batch.add_data_object(data_properties, self._index_name, _id)
|
||||||
ids.append(_id)
|
ids.append(_id)
|
||||||
return ids
|
return ids
|
||||||
|
@ -351,7 +351,7 @@ class Redis(VectorStore):
|
|||||||
if self.relevance_score_fn is None:
|
if self.relevance_score_fn is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"relevance_score_fn must be provided to"
|
"relevance_score_fn must be provided to"
|
||||||
" Weaviate constructor to normalize scores"
|
" Redis constructor to normalize scores"
|
||||||
)
|
)
|
||||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||||
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
|
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||||
|
@ -134,7 +134,12 @@ class Weaviate(VectorStore):
|
|||||||
for key in metadatas[i].keys():
|
for key in metadatas[i].keys():
|
||||||
data_properties[key] = json_serializable(metadatas[i][key])
|
data_properties[key] = json_serializable(metadatas[i][key])
|
||||||
|
|
||||||
_id = get_valid_uuid(uuid4())
|
# If the UUID of one of the objects already exists
|
||||||
|
# then the existing objectwill be replaced by the new object.
|
||||||
|
if "uuids" in kwargs:
|
||||||
|
_id = kwargs["uuids"][i]
|
||||||
|
else:
|
||||||
|
_id = get_valid_uuid(uuid4())
|
||||||
|
|
||||||
if self._embedding is not None:
|
if self._embedding is not None:
|
||||||
embeddings = self._embedding.embed_documents(list(doc))
|
embeddings = self._embedding.embed_documents(list(doc))
|
||||||
@ -385,7 +390,12 @@ class Weaviate(VectorStore):
|
|||||||
for key in metadatas[i].keys():
|
for key in metadatas[i].keys():
|
||||||
data_properties[key] = metadatas[i][key]
|
data_properties[key] = metadatas[i][key]
|
||||||
|
|
||||||
_id = get_valid_uuid(uuid4())
|
# If the UUID of one of the objects already exists
|
||||||
|
# then the existing objectwill be replaced by the new object.
|
||||||
|
if "uuids" in kwargs:
|
||||||
|
_id = kwargs["uuids"][i]
|
||||||
|
else:
|
||||||
|
_id = get_valid_uuid(uuid4())
|
||||||
|
|
||||||
# if an embedding strategy is not provided, we let
|
# if an embedding strategy is not provided, we let
|
||||||
# weaviate create the embedding. Note that this will only
|
# weaviate create the embedding. Note that this will only
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test Weaviate functionality."""
|
"""Test Weaviate functionality."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from typing import Generator, Union
|
from typing import Generator, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever:
|
|||||||
assert output == [
|
assert output == [
|
||||||
Document(page_content="foo", metadata={"page": 0}),
|
Document(page_content="foo", metadata={"page": 0}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@pytest.mark.vcr(ignore_localhost=True)
|
||||||
|
def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None:
|
||||||
|
"""Test end to end construction and MRR search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
metadatas = [{"page": i} for i in range(len(texts))]
|
||||||
|
# Weaviate replaces the object if the UUID already exists
|
||||||
|
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
|
||||||
|
|
||||||
|
client = Client(weaviate_url)
|
||||||
|
|
||||||
|
retriever = WeaviateHybridSearchRetriever(
|
||||||
|
client=client,
|
||||||
|
index_name=f"LangChain_{uuid4().hex}",
|
||||||
|
text_key="text",
|
||||||
|
attributes=["page"],
|
||||||
|
)
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
# hoge
|
||||||
|
retriever.add_documents(
|
||||||
|
[Document(page_content=text, metadata=metadatas[i])], uuids=[uuids[i]]
|
||||||
|
)
|
||||||
|
|
||||||
|
output = retriever.get_relevant_documents("foo")
|
||||||
|
assert len(output) == 1
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test Weaviate functionality."""
|
"""Test Weaviate functionality."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from typing import Generator, Union
|
from typing import Generator, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -80,6 +81,26 @@ class TestWeaviate:
|
|||||||
)
|
)
|
||||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||||
|
|
||||||
|
@pytest.mark.vcr(ignore_localhost=True)
|
||||||
|
def test_similarity_search_with_uuids(
|
||||||
|
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
||||||
|
) -> None:
|
||||||
|
"""Test end to end construction and search with uuids."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
# Weaviate replaces the object if the UUID already exists
|
||||||
|
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
|
||||||
|
|
||||||
|
metadatas = [{"page": i} for i in range(len(texts))]
|
||||||
|
docsearch = Weaviate.from_texts(
|
||||||
|
texts,
|
||||||
|
embedding_openai,
|
||||||
|
metadatas=metadatas,
|
||||||
|
weaviate_url=weaviate_url,
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=2)
|
||||||
|
assert len(output) == 1
|
||||||
|
|
||||||
@pytest.mark.vcr(ignore_localhost=True)
|
@pytest.mark.vcr(ignore_localhost=True)
|
||||||
def test_max_marginal_relevance_search(
|
def test_max_marginal_relevance_search(
|
||||||
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
||||||
@ -181,3 +202,23 @@ class TestWeaviate:
|
|||||||
Document(page_content="foo"),
|
Document(page_content="foo"),
|
||||||
Document(page_content="foo"),
|
Document(page_content="foo"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None:
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
embedding = FakeEmbeddings()
|
||||||
|
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, text) for text in texts]
|
||||||
|
|
||||||
|
docsearch = Weaviate.from_texts(
|
||||||
|
texts,
|
||||||
|
embedding=embedding,
|
||||||
|
weaviate_url=weaviate_url,
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weaviate replaces the object if the UUID already exists
|
||||||
|
docsearch.add_texts(["foo"], uuids=[uuids[0]])
|
||||||
|
output = docsearch.similarity_search_by_vector(
|
||||||
|
embedding.embed_query("foo"), k=2
|
||||||
|
)
|
||||||
|
assert output[0] == Document(page_content="foo")
|
||||||
|
assert output[1] != Document(page_content="foo")
|
||||||
|
Loading…
Reference in New Issue
Block a user