Redis: Fix keys (#10413)

- Description: Fixes user issue with custom keys for ``from_texts`` and
``from_documents`` methods.
  - Issue: #10411 
  - Tag maintainer: @baskaryan 
  - Twitter handle: @spartee
This commit is contained in:
Sam Partee 2023-09-09 20:46:26 -04:00 committed by GitHub
parent ee3f950a67
commit d09ef9eb52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 1 deletions

View File

@ -374,6 +374,11 @@ class Redis(VectorStore):
if "generate" in kwargs: if "generate" in kwargs:
kwargs.pop("generate") kwargs.pop("generate")
# see if the user specified keys
keys = None
if "keys" in kwargs:
keys = kwargs.pop("keys")
# Name of the search index if not given # Name of the search index if not given
if not index_name: if not index_name:
index_name = uuid.uuid4().hex index_name = uuid.uuid4().hex
@ -422,7 +427,7 @@ class Redis(VectorStore):
instance._create_index(dim=len(embeddings[0])) instance._create_index(dim=len(embeddings[0]))
# Add data to Redis # Add data to Redis
keys = instance.add_texts(texts, metadatas, embeddings) keys = instance.add_texts(texts, metadatas, embeddings, keys=keys)
return instance, keys return instance, keys
@classmethod @classmethod

View File

@ -136,6 +136,32 @@ def test_redis_from_documents(texts: List[str]) -> None:
assert drop(docsearch.index_name) assert drop(docsearch.index_name)
def test_custom_keys(texts: List[str]) -> None:
keys_in = ["test_key_1", "test_key_2", "test_key_3"]
docsearch, keys_out = Redis.from_texts_return_keys(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, keys=keys_in
)
assert keys_in == keys_out
assert drop(docsearch.index_name)
def test_custom_keys_from_docs(texts: List[str]) -> None:
keys_in = ["test_key_1", "test_key_2", "test_key_3"]
docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
docsearch = Redis.from_documents(
docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, keys=keys_in
)
client = docsearch.client
# test keys are correct
assert client.hget("test_key_1", "content")
# test metadata is stored
assert client.hget("test_key_1", "a") == bytes("b", "utf-8")
# test all keys are stored
assert client.hget("test_key_2", "content")
assert drop(docsearch.index_name)
# -- test filters -- # # -- test filters -- #