diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index fe973b4831..320c6730e3 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -374,6 +374,11 @@ class Redis(VectorStore): if "generate" in kwargs: kwargs.pop("generate") + # see if the user specified keys + keys = None + if "keys" in kwargs: + keys = kwargs.pop("keys") + # Name of the search index if not given if not index_name: index_name = uuid.uuid4().hex @@ -422,7 +427,7 @@ class Redis(VectorStore): instance._create_index(dim=len(embeddings[0])) # Add data to Redis - keys = instance.add_texts(texts, metadatas, embeddings) + keys = instance.add_texts(texts, metadatas, embeddings, keys=keys) return instance, keys @classmethod diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_redis.py b/libs/langchain/tests/integration_tests/vectorstores/test_redis.py index cbcd78d070..6128a8445a 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_redis.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_redis.py @@ -136,6 +136,32 @@ def test_redis_from_documents(texts: List[str]) -> None: assert drop(docsearch.index_name) +def test_custom_keys(texts: List[str]) -> None: + keys_in = ["test_key_1", "test_key_2", "test_key_3"] + docsearch, keys_out = Redis.from_texts_return_keys( + texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, keys=keys_in + ) + assert keys_in == keys_out + assert drop(docsearch.index_name) + + +def test_custom_keys_from_docs(texts: List[str]) -> None: + keys_in = ["test_key_1", "test_key_2", "test_key_3"] + docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts] + + docsearch = Redis.from_documents( + docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, keys=keys_in + ) + client = docsearch.client + # test keys are correct + assert client.hget("test_key_1", "content") + # test metadata is stored + assert client.hget("test_key_1", "a") == bytes("b", "utf-8") + # test all keys are stored + assert client.hget("test_key_2", "content") + assert drop(docsearch.index_name) + + # -- test filters -- #