From 7717c24fc4caa2b878e5dc88d4c0b78d2f932766 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Thu, 20 Jul 2023 19:00:05 -0700
Subject: [PATCH] fix redis cache chat model (#8041)

Redis cache currently stores model outputs as strings. Chat generations
have Messages, which contain more information than just a string. Until
the Redis cache supports fully storing messages, the cache should not
interact with chat generations.
---
 langchain/cache.py                     | 15 +++++++++++-
 .../cache/test_redis_cache.py          | 24 +++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/langchain/cache.py b/langchain/cache.py
index 5ba675e49d..b589a29891 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@@ -34,7 +35,7 @@ except ImportError:
 from langchain.embeddings.base import Embeddings
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import Generation
+from langchain.schema import ChatGeneration, Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore

 logger = logging.getLogger(__file__)
@@ -232,6 +233,12 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@@ -345,6 +352,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisSemanticCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {
diff --git a/tests/integration_tests/cache/test_redis_cache.py b/tests/integration_tests/cache/test_redis_cache.py
index 7ce18bf65e..b39cc0fac9 100644
--- a/tests/integration_tests/cache/test_redis_cache.py
+++ b/tests/integration_tests/cache/test_redis_cache.py
@@ -1,10 +1,12 @@
 """Test Redis cache functionality."""
+import pytest
 import redis

 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM

 REDIS_TEST_URL = "redis://localhost:6379"
@@ -28,6 +30,17 @@ def test_redis_cache() -> None:
     langchain.llm_cache.redis.flushall()


+def test_redis_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
+
+
 def test_redis_semantic_cache() -> None:
     langchain.llm_cache = RedisSemanticCache(
         embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@@ -53,3 +66,14 @@ def test_redis_semantic_cache() -> None:
     # expect different output now without cached result
     assert output != expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
+
+
+def test_redis_semantic_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
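
A minimal sketch of the reasoning in the commit message, i.e. why RedisCache's
string-only storage cannot round-trip a chat generation. This is illustrative
only, assuming the ChatGeneration, AIMessage, and Generation classes from
langchain.schema as of this change (the additional_kwargs content below is a
made-up example):

    from langchain.schema import AIMessage, ChatGeneration, Generation

    # A chat generation wraps a Message, which carries more than plain text
    # (role, additional_kwargs such as function-call metadata, and so on).
    chat_gen = ChatGeneration(
        message=AIMessage(content="hi", additional_kwargs={"name": "example"})
    )

    # RedisCache.update() writes only generation.text into the Redis HASH ...
    stored = chat_gen.text  # "hi" -- the Message structure is lost

    # ... and lookup() rebuilds plain Generation objects from those strings,
    # so the original AIMessage (and its additional_kwargs) cannot be recovered.
    restored = Generation(text=stored)
    assert isinstance(restored, Generation)
    assert not isinstance(restored, ChatGeneration)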