From 89be10f6b4f53eda96f1008c0e057c7df68f003f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Glauco=20Cust=C3=B3dio?=
Date: Mon, 14 Aug 2023 17:59:18 +0100
Subject: [PATCH] add ttl to RedisCache (#9068)

Add `ttl` (time to live) to `RedisCache`
---
 libs/langchain/langchain/cache.py             | 43 ++++++++++++++-----
 .../cache/test_redis_cache.py                 |  9 ++++
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py
index 2136acaad1..a78c060801 100644
--- a/libs/langchain/langchain/cache.py
+++ b/libs/langchain/langchain/cache.py
@@ -216,10 +216,25 @@ class SQLiteCache(SQLAlchemyCache):
 class RedisCache(BaseCache):
     """Cache that uses Redis as a backend."""
 
-    # TODO - implement a TTL policy in Redis
-
-    def __init__(self, redis_: Any):
-        """Initialize by passing in Redis instance."""
+    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
+        """
+        Initialize an instance of RedisCache.
+
+        This method initializes an object with Redis caching capabilities.
+        It takes a `redis_` parameter, which should be an instance of a Redis
+        client class, allowing the object to interact with a Redis
+        server for caching purposes.
+
+        Parameters:
+            redis_ (Any): An instance of a Redis client class
+                (e.g., redis.Redis) used for caching.
+                This allows the object to communicate with a
+                Redis server for caching operations.
+            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
+                If provided, it sets the time duration for how long cached
+                items will remain valid. If not provided, cached items will not
+                have an automatic expiration.
+        """
         try:
             from redis import Redis
         except ImportError:
@@ -230,6 +245,7 @@ class RedisCache(BaseCache):
         if not isinstance(redis_, Redis):
             raise ValueError("Please pass in Redis object.")
         self.redis = redis_
+        self.ttl = ttl
 
     def _key(self, prompt: str, llm_string: str) -> str:
         """Compute key from prompt and llm_string"""
@@ -261,12 +277,19 @@ class RedisCache(BaseCache):
                 return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
-        self.redis.hset(
-            key,
-            mapping={
-                str(idx): generation.text for idx, generation in enumerate(return_val)
-            },
-        )
+
+        with self.redis.pipeline() as pipe:
+            pipe.hset(
+                key,
+                mapping={
+                    str(idx): generation.text
+                    for idx, generation in enumerate(return_val)
+                },
+            )
+            if self.ttl is not None:
+                pipe.expire(key, self.ttl)
+
+            pipe.execute()
 
     def clear(self, **kwargs: Any) -> None:
         """Clear cache. If `asynchronous` is True, flush asynchronously."""
diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
index 7d43c9a405..5d51a12e8e 100644
--- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
+++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
@@ -11,6 +11,15 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
 REDIS_TEST_URL = "redis://localhost:6379"
 
 
+def test_redis_cache_ttl() -> None:
+    import redis
+
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
+    langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
+    key = langchain.llm_cache._key("foo", "bar")
+    assert langchain.llm_cache.redis.pttl(key) > 0
+
+
 def test_redis_cache() -> None:
     import redis
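
Usage note (not part of the patch): a minimal sketch of how the new `ttl` parameter is meant to be used, mirroring the integration test above. It assumes a Redis server reachable at `redis://localhost:6379`, the `langchain.llm_cache` global used elsewhere in the library, and an arbitrary 60-second TTL; prompt and llm_string values are placeholders.

```python
import langchain
import redis
from langchain.cache import RedisCache
from langchain.schema import Generation

# Entries written through this cache expire 60 seconds after being stored.
langchain.llm_cache = RedisCache(
    redis_=redis.Redis.from_url("redis://localhost:6379"),
    ttl=60,
)

# Store a generation, then confirm Redis attached an expiry to the hash key.
langchain.llm_cache.update("prompt", "llm-config", [Generation(text="cached answer")])
key = langchain.llm_cache._key("prompt", "llm-config")
assert langchain.llm_cache.redis.pttl(key) > 0  # pttl > 0 means a TTL is set
```

The patch issues `HSET` and `EXPIRE` through a single pipeline, so both commands travel to Redis in one round trip and the expiry is applied together with the cached write; when `ttl` is None the `EXPIRE` call is skipped and the previous non-expiring behavior is preserved.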