fix(cache): use dumps for RedisCache (#10408)

# Description
Attempts to fix `RedisCache` for `ChatGeneration`s by using the same `loads` and `dumps` helpers that the SQLAlchemy cache by @hwchase17 already uses. This is safer than a pickle dump because it will not execute arbitrary code during deserialisation.
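
For illustration only (not part of this PR), here is a minimal sketch of the round trip the cache now relies on, assuming `ChatGeneration` is registered with langchain's JSON serializer (which this change depends on):

```python
# Sketch only: the exact JSON emitted by `dumps` is an implementation detail.
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema.messages import AIMessage
from langchain.schema.output import ChatGeneration

gen = ChatGeneration(message=AIMessage(content="fizz"))
serialized = dumps(gen)       # plain JSON string, safe to store in Redis
restored = loads(serialized)  # rebuilt from JSON, no arbitrary code execution
assert restored == gen
```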

# Issues
#7722 & #8666 

# Dependencies
None, but this also removes the warning introduced in #8041 by @baskaryan

Handle: @jaikanthjay46

@@ -25,7 +25,6 @@ import hashlib
 import inspect
 import json
 import logging
-import warnings
 from datetime import timedelta
 from functools import lru_cache
 from typing import (
@@ -54,7 +53,7 @@ except ImportError:
 from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import ChatGeneration, Generation
+from langchain.schema import Generation
 from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
 from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_env
@@ -306,7 +305,18 @@ class RedisCache(BaseCache):
         results = self.redis.hgetall(self._key(prompt, llm_string))
         if results:
             for _, text in results.items():
-                generations.append(Generation(text=text))
+                try:
+                    generations.append(loads(text))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.append(Generation(text=text))
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -317,12 +327,6 @@
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
@@ -330,7 +334,7 @@
             pipe.hset(
                 key,
                 mapping={
-                    str(idx): generation.text
+                    str(idx): dumps(generation)
                     for idx, generation in enumerate(return_val)
                 },
             )
@@ -441,9 +445,20 @@ class RedisSemanticCache(BaseCache):
         )
         if results:
             for document in results:
-                generations.extend(
-                    _load_generations_from_json(document.metadata["return_val"])
-                )
+                try:
+                    generations.extend(loads(document.metadata["return_val"]))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.extend(
+                        _load_generations_from_json(document.metadata["return_val"])
+                    )
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -454,18 +469,12 @@
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         llm_cache = self._get_llm_cache(llm_string)
         _dump_generations_to_json([g for g in return_val])
         metadata = {
             "llm_string": llm_string,
             "prompt": prompt,
-            "return_val": _dump_generations_to_json([g for g in return_val]),
+            "return_val": dumps([g for g in return_val]),
         }
         llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
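
Before the test-file hunks below, a rough usage sketch of what the change enables; the Redis URL and `ChatOpenAI` model are placeholder assumptions and are not part of this diff:

```python
# Hypothetical usage, assuming a local Redis instance and OpenAI credentials.
import langchain
from redis import Redis
from langchain.cache import RedisCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = RedisCache(redis_=Redis.from_url("redis://localhost:6379"))

chat = ChatOpenAI()
first = chat.predict("Tell me a joke")   # calls the model; result stored via dumps()
second = chat.predict("Tell me a joke")  # served from the Redis cache via loads()
assert first == second
```

The test hunks below exercise the same flow against `FakeChatModel`.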

@@ -6,8 +6,11 @@ import pytest
 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain.schema.output import ChatGeneration
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
     FakeEmbeddings,
@@ -56,9 +59,17 @@ def test_redis_cache_chat() -> None:
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
-    with pytest.warns():
-        llm.predict("foo")
-    llm.predict("foo")
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.redis.flushall()
@@ -120,9 +131,16 @@ def test_redis_semantic_cache_chat() -> None:
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    with pytest.warns():
-        llm.predict("foo")
-    llm.predict("foo")
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
