forked from Archives/langchain
Accept uuids kwargs for weaviate (#4800)
# Accept uuids kwargs for Weaviate. Fixes #4791.
This commit is contained in:
parent
e78c9be312
commit
6561efebb7
@ -47,7 +47,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# added text_key
|
||||
def add_documents(self, docs: List[Document]) -> List[str]:
|
||||
def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Upload documents to Weaviate."""
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
@ -56,7 +56,14 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
for i, doc in enumerate(docs):
|
||||
metadata = doc.metadata or {}
|
||||
data_properties = {self._text_key: doc.page_content, **metadata}
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
else:
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
batch.add_data_object(data_properties, self._index_name, _id)
|
||||
ids.append(_id)
|
||||
return ids
|
||||
|
@ -351,7 +351,7 @@ class Redis(VectorStore):
|
||||
if self.relevance_score_fn is None:
|
||||
raise ValueError(
|
||||
"relevance_score_fn must be provided to"
|
||||
" Weaviate constructor to normalize scores"
|
||||
" Redis constructor to normalize scores"
|
||||
)
|
||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||
|
@ -134,7 +134,12 @@ class Weaviate(VectorStore):
|
||||
for key in metadatas[i].keys():
|
||||
data_properties[key] = json_serializable(metadatas[i][key])
|
||||
|
||||
_id = get_valid_uuid(uuid4())
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
else:
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
if self._embedding is not None:
|
||||
embeddings = self._embedding.embed_documents(list(doc))
|
||||
@ -385,7 +390,12 @@ class Weaviate(VectorStore):
|
||||
for key in metadatas[i].keys():
|
||||
data_properties[key] = metadatas[i][key]
|
||||
|
||||
_id = get_valid_uuid(uuid4())
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
else:
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
# if an embedding strategy is not provided, we let
|
||||
# weaviate create the embedding. Note that this will only
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test Weaviate functionality."""
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
from uuid import uuid4
|
||||
|
||||
@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever:
|
||||
assert output == [
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
]
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
|
||||
def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None:
|
||||
"""Test end to end construction and MRR search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
# Weaviate replaces the object if the UUID already exists
|
||||
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
|
||||
|
||||
client = Client(weaviate_url)
|
||||
|
||||
retriever = WeaviateHybridSearchRetriever(
|
||||
client=client,
|
||||
index_name=f"LangChain_{uuid4().hex}",
|
||||
text_key="text",
|
||||
attributes=["page"],
|
||||
)
|
||||
for i, text in enumerate(texts):
|
||||
# Every document is added under the same UUID, so each one replaces the last.
|
||||
retriever.add_documents(
|
||||
[Document(page_content=text, metadata=metadatas[i])], uuids=[uuids[i]]
|
||||
)
|
||||
|
||||
output = retriever.get_relevant_documents("foo")
|
||||
assert len(output) == 1
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test Weaviate functionality."""
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
@ -80,6 +81,26 @@ class TestWeaviate:
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
|
||||
def test_similarity_search_with_uuids(
|
||||
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
||||
) -> None:
|
||||
"""Test end to end construction and search with uuids."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
# Weaviate replaces the object if the UUID already exists
|
||||
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for text in texts]
|
||||
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = Weaviate.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
metadatas=metadatas,
|
||||
weaviate_url=weaviate_url,
|
||||
uuids=uuids,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
assert len(output) == 1
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
|
||||
def test_max_marginal_relevance_search(
|
||||
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
||||
@ -181,3 +202,23 @@ class TestWeaviate:
|
||||
Document(page_content="foo"),
|
||||
Document(page_content="foo"),
|
||||
]
|
||||
|
||||
def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None:
|
||||
texts = ["foo", "bar", "baz"]
|
||||
embedding = FakeEmbeddings()
|
||||
uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, text) for text in texts]
|
||||
|
||||
docsearch = Weaviate.from_texts(
|
||||
texts,
|
||||
embedding=embedding,
|
||||
weaviate_url=weaviate_url,
|
||||
uuids=uuids,
|
||||
)
|
||||
|
||||
# Weaviate replaces the object if the UUID already exists
|
||||
docsearch.add_texts(["foo"], uuids=[uuids[0]])
|
||||
output = docsearch.similarity_search_by_vector(
|
||||
embedding.embed_query("foo"), k=2
|
||||
)
|
||||
assert output[0] == Document(page_content="foo")
|
||||
assert output[1] != Document(page_content="foo")
|
||||
|
Loading…
Reference in New Issue
Block a user