mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
parent
dfb93dd2b5
commit
d21333d710
115
libs/langchain/langchain/storage/redis.py
Normal file
115
libs/langchain/langchain/storage/redis.py
Normal file
@ -0,0 +1,115 @@
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from langchain.schema import BaseStore
|
||||
|
||||
|
||||
class RedisStore(BaseStore[str, bytes]):
|
||||
"""BaseStore implementation using Redis as the underlying store.
|
||||
|
||||
Examples:
|
||||
Create a RedisStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Instantiate the RedisStore with a Redis connection
|
||||
from langchain.storage import RedisStore
|
||||
from langchain.vectorstores.redis import get_client
|
||||
|
||||
client = get_client('redis://localhost:6379')
|
||||
redis_store = RedisStore(client)
|
||||
|
||||
# Set values for keys
|
||||
redis_store.mset([("key1", b"value1"), ("key2", b"value2")])
|
||||
|
||||
# Get values for keys
|
||||
values = redis_store.mget(["key1", "key2"])
|
||||
# [b"value1", b"value2"]
|
||||
|
||||
# Delete keys
|
||||
redis_store.mdelete(["key1"])
|
||||
|
||||
# Iterate over keys
|
||||
for key in redis_store.yield_keys():
|
||||
print(key)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, client: Any, *, ttl: Optional[int] = None, namespace: Optional[str] = None
|
||||
) -> None:
|
||||
"""Initialize the RedisStore with a Redis connection.
|
||||
|
||||
Args:
|
||||
client: A Redis connection instance
|
||||
ttl: time to expire keys in seconds if provided,
|
||||
if None keys will never expire
|
||||
namespace: if provided, all keys will be prefixed with this namespace
|
||||
"""
|
||||
try:
|
||||
from redis import Redis
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The RedisStore requires the redis library to be installed. "
|
||||
"pip install redis"
|
||||
) from e
|
||||
|
||||
if not isinstance(client, Redis):
|
||||
raise TypeError(
|
||||
f"Expected Redis client, got {type(client).__name__} instead."
|
||||
)
|
||||
|
||||
self.client = client
|
||||
|
||||
if not isinstance(ttl, int) and ttl is not None:
|
||||
raise TypeError(f"Expected int or None, got {type(ttl)} instead.")
|
||||
|
||||
self.ttl = ttl
|
||||
self.namespace = namespace
|
||||
self.namespace_delimiter = "/"
|
||||
|
||||
def _get_prefixed_key(self, key: str) -> str:
|
||||
"""Get the key with the namespace prefix.
|
||||
|
||||
Args:
|
||||
key (str): The original key.
|
||||
|
||||
Returns:
|
||||
str: The key with the namespace prefix.
|
||||
"""
|
||||
if self.namespace:
|
||||
return f"{self.namespace}{self.namespace_delimiter}{key}"
|
||||
return key
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
return cast(
|
||||
List[Optional[bytes]],
|
||||
self.client.mget([self._get_prefixed_key(key) for key in keys]),
|
||||
)
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
"""Set the given key-value pairs."""
|
||||
pipe = self.client.pipeline()
|
||||
|
||||
for key, value in key_value_pairs:
|
||||
pipe.set(self._get_prefixed_key(key), value, ex=self.ttl)
|
||||
pipe.execute()
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys."""
|
||||
_keys = [self._get_prefixed_key(key) for key in keys]
|
||||
self.client.delete(*_keys)
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store."""
|
||||
if prefix:
|
||||
pattern = self._get_prefixed_key(prefix)
|
||||
else:
|
||||
pattern = self._get_prefixed_key("*")
|
||||
scan_iter = cast(Iterator[bytes], self.client.scan_iter(match=pattern))
|
||||
for key in scan_iter:
|
||||
decoded_key = key.decode("utf-8")
|
||||
if self.namespace:
|
||||
relative_key = decoded_key[len(self.namespace) + 1 :]
|
||||
yield relative_key
|
||||
else:
|
||||
yield decoded_key
|
105
libs/langchain/tests/integration_tests/storage/test_redis.py
Normal file
105
libs/langchain/tests/integration_tests/storage/test_redis.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Implement integration tests for Redis storage."""
|
||||
import os
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from langchain.storage.redis import RedisStore
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
try:
|
||||
from redis import Redis
|
||||
except ImportError:
|
||||
# Ignoring mypy here to allow assignment of Any to Redis in the event
|
||||
# that redis is not installed.
|
||||
Redis = Any # type:ignore
|
||||
else:
|
||||
Redis = Any # type:ignore
|
||||
|
||||
|
||||
pytest.importorskip("redis")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_client() -> Redis:
|
||||
"""Yield redis client."""
|
||||
# Using standard port, but protecting against accidental data loss
|
||||
# by requiring a password.
|
||||
# This fixture flushes the database!
|
||||
# The only role of the password is to prevent users from accidentally
|
||||
# deleting their data.
|
||||
# The password should establish the identity of the server being.
|
||||
port = 6379
|
||||
password = os.environ.get("REDIS_PASSWORD") or str(uuid.uuid4())
|
||||
client = redis.Redis(host="localhost", port=port, password=password, db=0)
|
||||
try:
|
||||
client.ping()
|
||||
except redis.exceptions.ConnectionError:
|
||||
pytest.skip(
|
||||
"Redis server is not running or is not accessible. "
|
||||
"Verify that credentials are correct. "
|
||||
)
|
||||
# ATTENTION: This will delete all keys in the database!
|
||||
client.flushdb()
|
||||
return client
|
||||
|
||||
|
||||
def test_mget(redis_client: Redis) -> None:
|
||||
"""Test mget method."""
|
||||
store = RedisStore(redis_client, ttl=None)
|
||||
keys = ["key1", "key2"]
|
||||
redis_client.mset({"key1": b"value1", "key2": b"value2"})
|
||||
result = store.mget(keys)
|
||||
assert result == [b"value1", b"value2"]
|
||||
|
||||
|
||||
def test_mset(redis_client: Redis) -> None:
|
||||
"""Test that multiple keys can be set."""
|
||||
store = RedisStore(redis_client, ttl=None)
|
||||
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
||||
store.mset(key_value_pairs)
|
||||
result = redis_client.mget(["key1", "key2"])
|
||||
assert result == [b"value1", b"value2"]
|
||||
|
||||
|
||||
def test_mdelete(redis_client: Redis) -> None:
|
||||
"""Test that deletion works as expected."""
|
||||
store = RedisStore(redis_client, ttl=None)
|
||||
keys = ["key1", "key2"]
|
||||
redis_client.mset({"key1": b"value1", "key2": b"value2"})
|
||||
store.mdelete(keys)
|
||||
result = redis_client.mget(keys)
|
||||
assert result == [None, None]
|
||||
|
||||
|
||||
def test_yield_keys(redis_client: Redis) -> None:
|
||||
store = RedisStore(redis_client, ttl=None)
|
||||
redis_client.mset({"key1": b"value1", "key2": b"value2"})
|
||||
assert sorted(store.yield_keys()) == ["key1", "key2"]
|
||||
assert sorted(store.yield_keys(prefix="key*")) == ["key1", "key2"]
|
||||
assert sorted(store.yield_keys(prefix="lang*")) == []
|
||||
|
||||
|
||||
def test_namespace(redis_client: Redis) -> None:
|
||||
"""Test that a namespace is prepended to all keys properly."""
|
||||
store = RedisStore(redis_client, ttl=None, namespace="meow")
|
||||
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
||||
store.mset(key_value_pairs)
|
||||
|
||||
assert sorted(redis_client.scan_iter("*")) == [
|
||||
b"meow/key1",
|
||||
b"meow/key2",
|
||||
]
|
||||
|
||||
store.mdelete(["key1"])
|
||||
|
||||
assert sorted(redis_client.scan_iter("*")) == [
|
||||
b"meow/key2",
|
||||
]
|
||||
|
||||
assert list(store.yield_keys()) == ["key2"]
|
||||
assert list(store.yield_keys(prefix="key*")) == ["key2"]
|
||||
assert list(store.yield_keys(prefix="key1")) == []
|
11
libs/langchain/tests/unit_tests/storage/test_redis.py
Normal file
11
libs/langchain/tests/unit_tests/storage/test_redis.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""Light weight unit test that attempts to import RedisStore.
|
||||
|
||||
The actual code is tested in integration tests.
|
||||
|
||||
This test is intended to catch errors in the import process.
|
||||
"""
|
||||
|
||||
|
||||
def test_import_storage() -> None:
|
||||
"""Attempt to import storage modules."""
|
||||
from langchain.storage.redis import RedisStore # noqa
|
Loading…
Reference in New Issue
Block a user