diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py
index 36a3b09ad0..e77528924c 100644
--- a/libs/langchain/langchain/cache.py
+++ b/libs/langchain/langchain/cache.py
@@ -25,7 +25,6 @@ import hashlib
 import inspect
 import json
 import logging
-import warnings
 from datetime import timedelta
 from functools import lru_cache
 from typing import (
@@ -54,7 +53,7 @@ except ImportError:
 from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import ChatGeneration, Generation
+from langchain.schema import Generation
 from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
 from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_env
@@ -306,7 +305,18 @@ class RedisCache(BaseCache):
         results = self.redis.hgetall(self._key(prompt, llm_string))
         if results:
             for _, text in results.items():
-                generations.append(Generation(text=text))
+                try:
+                    generations.append(loads(text))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.append(Generation(text=text))
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -317,12 +327,6 @@
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
 
@@ -330,7 +334,7 @@
             pipe.hset(
                 key,
                 mapping={
-                    str(idx): generation.text
+                    str(idx): dumps(generation)
                     for idx, generation in enumerate(return_val)
                 },
             )
@@ -441,9 +445,20 @@ class RedisSemanticCache(BaseCache):
         )
         if results:
             for document in results:
-                generations.extend(
-                    _load_generations_from_json(document.metadata["return_val"])
-                )
+                try:
+                    generations.extend(loads(document.metadata["return_val"]))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.extend(
+                        _load_generations_from_json(document.metadata["return_val"])
+                    )
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -454,18 +469,12 @@
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         llm_cache = self._get_llm_cache(llm_string)
-        _dump_generations_to_json([g for g in return_val])
+
         metadata = {
             "llm_string": llm_string,
             "prompt": prompt,
-            "return_val": _dump_generations_to_json([g for g in return_val]),
+            "return_val": dumps([g for g in return_val]),
         }
         llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
 
diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
index e78aedd204..4edca62ee7 100644
--- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
+++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py
@@ -6,8 +6,11 @@ import pytest
 
 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain.schema.output import ChatGeneration
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
     FakeEmbeddings,
@@ -56,9 +59,17 @@ def test_redis_cache_chat() -> None:
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
-    with pytest.warns():
-        llm.predict("foo")
-        llm.predict("foo")
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.redis.flushall()
 
 
@@ -120,9 +131,16 @@ def test_redis_semantic_cache_chat() -> None:
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    with pytest.warns():
-        llm.predict("foo")
-        llm.predict("foo")
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.clear(llm_string=llm_string)