fix redis cache chat model (#8041)

Redis cache currently stores model outputs as plain strings. Chat generations
wrap Messages, which carry more information than just a string (the message
type and any additional kwargs). Until Redis cache supports fully storing
messages, the cache should not interact with chat generations.
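
As a quick illustration (not part of the diff; the additional_kwargs payload is a made-up example), a chat generation cannot be rebuilt from its text alone:

from langchain.schema import AIMessage, ChatGeneration, Generation

# A plain LLM generation is fully described by its text, so a string
# round-trip is lossless:
gen = Generation(text="hi")
assert Generation(text=gen.text) == gen

# A chat generation wraps a Message; rebuilding it from .text alone drops
# the message type and additional_kwargs, so a string-only cache is lossy:
chat_gen = ChatGeneration(
    message=AIMessage(content="hi", additional_kwargs={"name": "assistant-1"})
)
assert ChatGeneration(message=AIMessage(content=chat_gen.text)) != chat_gen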
Bagatur committed 7717c24fc4 (parent 973593c5c7)

@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@@ -34,7 +35,7 @@ except ImportError:
 from langchain.embeddings.base import Embeddings
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import Generation
+from langchain.schema import ChatGeneration, Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore
 
 logger = logging.getLogger(__file__)
@@ -232,6 +233,12 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@@ -345,6 +352,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisSemanticCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {
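
Together the two guards make chat outputs a warn-and-skip no-op instead of a lossy cache entry, while plain LLM generations are cached exactly as before (warning rather than raising presumably keeps existing chat pipelines running). A minimal sketch of the resulting behavior (not part of the diff; assumes a Redis server at redis://localhost:6379):

import warnings

import redis

import langchain
from langchain.cache import RedisCache
from langchain.schema import AIMessage, ChatGeneration, Generation

langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"))

# String generations are still written to the Redis HASH.
langchain.llm_cache.update("prompt", "llm_string", [Generation(text="hello")])
assert langchain.llm_cache.lookup("prompt", "llm_string") is not None

# Chat generations now emit a warning and are skipped entirely.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    langchain.llm_cache.update(
        "chat prompt",
        "llm_string",
        [ChatGeneration(message=AIMessage(content="hello"))],
    )
assert len(caught) == 1
assert langchain.llm_cache.lookup("chat prompt", "llm_string") is None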

@@ -1,10 +1,12 @@
 """Test Redis cache functionality."""
+import pytest
 import redis
 
 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 REDIS_TEST_URL = "redis://localhost:6379"
@@ -28,6 +30,17 @@ def test_redis_cache() -> None:
     langchain.llm_cache.redis.flushall()
 
 
+def test_redis_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+        llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
+
+
 def test_redis_semantic_cache() -> None:
     langchain.llm_cache = RedisSemanticCache(
         embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@@ -53,3 +66,17 @@ def test_redis_semantic_cache() -> None:
     # expect different output now without cached result
     assert output != expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
+
+
+def test_redis_semantic_cache_chat() -> None:
+    langchain.llm_cache = RedisSemanticCache(
+        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+    )
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    with pytest.warns():
+        llm.predict("foo")
+        llm.predict("foo")
+    langchain.llm_cache.clear(llm_string=llm_string)
