From 068142fce203d531db7aca381d3cbf034e2bbc60 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Sat, 24 Jun 2023 20:45:09 +0200 Subject: [PATCH] Add caching to BaseChatModel (issue #1644) (#5089) # Add caching to BaseChatModel Fixes #1644 (Sidenote: While testing, I noticed we have multiple implementations of Fake LLMs, used for testing. I consolidated them.) ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: Models - @hwchase17 - @agola11 Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord: RicChilligerDude#7589 --------- Co-authored-by: Harrison Chase --- .../models/chat/how_to/chat_model_caching.mdx | 9 ++ .../models/chat/how_to/chat_model_caching.mdx | 97 ++++++++++++ .../models/llms/how_to/llm_caching.mdx | 8 +- langchain/cache.py | 52 ++++++- langchain/chat_models/__init__.py | 2 + langchain/chat_models/base.py | 125 ++++++++++++--- langchain/chat_models/fake.py | 33 ++++ langchain/schema.py | 10 ++ tests/unit_tests/agents/test_react.py | 32 +--- tests/unit_tests/llms/test_callbacks.py | 14 +- tests/unit_tests/test_cache.py | 146 ++++++++++++++++++ 11 files changed, 465 insertions(+), 63 deletions(-) create mode 100644 docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx create mode 100644 docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx create mode 100644 langchain/chat_models/fake.py create mode 100644 tests/unit_tests/test_cache.py diff --git a/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx b/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx new file mode 100644 index 0000000000..c34cb22326 --- /dev/null +++ b/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx @@ -0,0 +1,9 @@ +# Caching +LangChain provides an optional caching layer for Chat Models. This is useful for two reasons: + +It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times. +It can speed up your application by reducing the number of API calls you make to the LLM provider. + +import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx" + + diff --git a/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx b/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx new file mode 100644 index 0000000000..580ca37488 --- /dev/null +++ b/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx @@ -0,0 +1,97 @@ +```python +import langchain +from langchain.chat_models import ChatOpenAI + +llm = ChatOpenAI() +``` + +## In Memory Cache + + +```python +from langchain.cache import InMemoryCache +langchain.llm_cache = InMemoryCache() + +# The first time, it is not yet in cache, so it should take longer +llm.predict("Tell me a joke") +``` + + + +``` + CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms + Wall time: 4.83 s + + + "\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!" +``` + + + + +```python +# The second time it is, so it goes faster +llm.predict("Tell me a joke") +``` + + + +``` + CPU times: user 238 µs, sys: 143 µs, total: 381 µs + Wall time: 1.76 ms + + + '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.' +``` + + + +## SQLite Cache + + +```bash +rm .langchain.db +``` + + +```python +# We can do the same thing with a SQLite cache +from langchain.cache import SQLiteCache +langchain.llm_cache = SQLiteCache(database_path=".langchain.db") +``` + + +```python +# The first time, it is not yet in cache, so it should take longer +llm.predict("Tell me a joke") +``` + + + +``` + CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms + Wall time: 825 ms + + + '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.' +``` + + + + +```python +# The second time it is, so it goes faster +llm.predict("Tell me a joke") +``` + + + +``` + CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms + Wall time: 2.67 ms + + + '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.' +``` + + diff --git a/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx b/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx index daa43c560f..5bb436ff82 100644 --- a/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx +++ b/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx @@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache langchain.llm_cache = InMemoryCache() # The first time, it is not yet in cache, so it should take longer -llm("Tell me a joke") +llm.predict("Tell me a joke") ``` @@ -32,7 +32,7 @@ llm("Tell me a joke") ```python # The second time it is, so it goes faster -llm("Tell me a joke") +llm.predict("Tell me a joke") ``` @@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db") ```python # The first time, it is not yet in cache, so it should take longer -llm("Tell me a joke") +llm.predict("Tell me a joke") ``` @@ -82,7 +82,7 @@ llm("Tell me a joke") ```python # The second time it is, so it goes faster -llm("Tell me a joke") +llm.predict("Tell me a joke") ``` diff --git a/langchain/cache.py b/langchain/cache.py index 857e3dc5a6..2cfb7ff372 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -4,6 +4,7 @@ from __future__ import annotations import hashlib import inspect import json +import logging from abc import ABC, abstractmethod from datetime import timedelta from typing import ( @@ -11,8 +12,8 @@ from typing import ( Any, Callable, Dict, - List, Optional, + Sequence, Tuple, Type, Union, @@ -31,13 +32,17 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base from langchain.embeddings.base import Embeddings +from langchain.load.dump import dumps +from langchain.load.load import loads from langchain.schema import Generation from langchain.vectorstores.redis import Redis as RedisVectorstore +logger = logging.getLogger(__file__) + if TYPE_CHECKING: import momento -RETURN_VAL_TYPE = List[Generation] +RETURN_VAL_TYPE = Sequence[Generation] def _hash(_input: str) -> str: @@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache): with Session(self.engine) as session: rows = session.execute(stmt).fetchall() if rows: - return [Generation(text=row[0]) for row in rows] + try: + return [loads(row[0]) for row in rows] + except Exception: + logger.warning( + "Retrieving a cache value that could not be deserialized " + "properly. This is likely due to the cache being in an " + "older format. Please recreate your cache to avoid this " + "error." + ) + # In a previous life we stored the raw text directly + # in the table, so assume it's in that format. + return [Generation(text=row[0]) for row in rows] return None def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update based on prompt and llm_string.""" items = [ - self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i) + self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i) for i, gen in enumerate(return_val) ] with Session(self.engine) as session, session.begin(): @@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache): def clear(self, **kwargs: Any) -> None: """Clear cache.""" with Session(self.engine) as session: - session.execute(self.cache_schema.delete()) + session.query(self.cache_schema).delete() class SQLiteCache(SQLAlchemyCache): @@ -209,6 +225,12 @@ class RedisCache(BaseCache): def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and llm_string.""" + for gen in return_val: + if not isinstance(return_val, Generation): + raise ValueError( + "RedisCache only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) # Write to a Redis HASH key = self._key(prompt, llm_string) self.redis.hset( @@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache): def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and llm_string.""" + for gen in return_val: + if not isinstance(return_val, Generation): + raise ValueError( + "RedisSemanticCache only supports caching of " + f"normal LLM generations, got {type(gen)}" + ) llm_cache = self._get_llm_cache(llm_string) # Write to vectorstore metadata = { @@ -426,6 +454,12 @@ class GPTCache(BaseCache): First, retrieve the corresponding cache object using the `llm_string` parameter, and then store the `prompt` and `return_val` in the cache object. """ + for gen in return_val: + if not isinstance(return_val, Generation): + raise ValueError( + "GPTCache only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) from gptcache.adapter.api import put _gptcache = self._get_gptcache(llm_string) @@ -567,7 +601,7 @@ class MomentoCache(BaseCache): """ from momento.responses import CacheGet - generations = [] + generations: RETURN_VAL_TYPE = [] get_response = self.cache_client.get( self.cache_name, self.__key(prompt, llm_string) @@ -593,6 +627,12 @@ class MomentoCache(BaseCache): SdkException: Momento service or network error Exception: Unexpected response """ + for gen in return_val: + if not isinstance(return_val, Generation): + raise ValueError( + "Momento only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) key = self.__key(prompt, llm_string) value = _dump_generations_to_json(return_val) set_response = self.cache_client.set(self.cache_name, key, value, self.ttl) diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py index 0c048ec3f3..39952027df 100644 --- a/langchain/chat_models/__init__.py +++ b/langchain/chat_models/__init__.py @@ -1,5 +1,6 @@ from langchain.chat_models.anthropic import ChatAnthropic from langchain.chat_models.azure_openai import AzureChatOpenAI +from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.google_palm import ChatGooglePalm from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI @@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI __all__ = [ "ChatOpenAI", "AzureChatOpenAI", + "FakeListChatModel", "PromptLayerChatOpenAI", "ChatAnthropic", "ChatGooglePalm", diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 5fa949ef0a..c4bb81305a 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -17,7 +17,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, Callbacks, ) -from langchain.load.dump import dumpd +from langchain.load.dump import dumpd, dumps from langchain.schema import ( AIMessage, BaseMessage, @@ -35,6 +35,7 @@ def _get_verbosity() -> bool: class BaseChatModel(BaseLanguageModel, ABC): + cache: Optional[bool] = None verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" callbacks: Callbacks = Field(default=None, exclude=True) @@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC): def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return {} + def _get_invocation_params( + self, + stop: Optional[List[str]] = None, + ) -> dict: + params = self.dict() + params["stop"] = stop + return params + + def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str: + if self.lc_serializable: + params = {**kwargs, **{"stop": stop}} + param_string = str(sorted([(k, v) for k, v in params.items()])) + llm_string = dumps(self) + return llm_string + "---" + param_string + else: + params = self._get_invocation_params(stop=stop) + params = {**params, **kwargs} + return str(sorted([(k, v) for k, v in params.items()])) + def generate( self, messages: List[List[BaseMessage]], @@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Top Level call""" - - params = self.dict() - params["stop"] = stop + params = self._get_invocation_params(stop=stop) options = {"stop": stop} callback_manager = CallbackManager.configure( @@ -87,14 +105,11 @@ class BaseChatModel(BaseLanguageModel, ABC): dumpd(self), messages, invocation_params=params, options=options ) - new_arg_supported = inspect.signature(self._generate).parameters.get( - "run_manager" - ) try: results = [ - self._generate(m, stop=stop, run_manager=run_manager, **kwargs) - if new_arg_supported - else self._generate(m, stop=stop) + self._generate_with_cache( + m, stop=stop, run_manager=run_manager, **kwargs + ) for m in messages ] except (KeyboardInterrupt, Exception) as e: @@ -118,8 +133,7 @@ class BaseChatModel(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Top Level call""" - params = self.dict() - params["stop"] = stop + params = self._get_invocation_params(stop=stop) options = {"stop": stop} callback_manager = AsyncCallbackManager.configure( @@ -133,15 +147,12 @@ class BaseChatModel(BaseLanguageModel, ABC): dumpd(self), messages, invocation_params=params, options=options ) - new_arg_supported = inspect.signature(self._agenerate).parameters.get( - "run_manager" - ) try: results = await asyncio.gather( *[ - self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs) - if new_arg_supported - else self._agenerate(m, stop=stop) + self._agenerate_with_cache( + m, stop=stop, run_manager=run_manager, **kwargs + ) for m in messages ] ) @@ -178,6 +189,84 @@ class BaseChatModel(BaseLanguageModel, ABC): prompt_messages, stop=stop, callbacks=callbacks, **kwargs ) + def _generate_with_cache( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) + disregard_cache = self.cache is not None and not self.cache + if langchain.llm_cache is None or disregard_cache: + # This happens when langchain.cache is None, but self.cache is True + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." + ) + if new_arg_supported: + return self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + return self._generate(messages, stop=stop, **kwargs) + else: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = langchain.llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + else: + if new_arg_supported: + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = self._generate(messages, stop=stop, **kwargs) + langchain.llm_cache.update(prompt, llm_string, result.generations) + return result + + async def _agenerate_with_cache( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) + disregard_cache = self.cache is not None and not self.cache + if langchain.llm_cache is None or disregard_cache: + # This happens when langchain.cache is None, but self.cache is True + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." + ) + if new_arg_supported: + return await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + return await self._agenerate(messages, stop=stop, **kwargs) + else: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = langchain.llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + else: + if new_arg_supported: + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = await self._agenerate(messages, stop=stop, **kwargs) + langchain.llm_cache.update(prompt, llm_string, result.generations) + return result + @abstractmethod def _generate( self, diff --git a/langchain/chat_models/fake.py b/langchain/chat_models/fake.py new file mode 100644 index 0000000000..0149b1ce72 --- /dev/null +++ b/langchain/chat_models/fake.py @@ -0,0 +1,33 @@ +"""Fake ChatModel for testing purposes.""" +from typing import Any, List, Mapping, Optional + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import SimpleChatModel +from langchain.schema import BaseMessage + + +class FakeListChatModel(SimpleChatModel): + """Fake ChatModel for testing purposes.""" + + responses: List + i: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-list-chat-model" + + def _call( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """First try to lookup in queries, else return 'foo' or 'bar'.""" + response = self.responses[self.i] + self.i += 1 + return response + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {"responses": self.responses} diff --git a/langchain/schema.py b/langchain/schema.py index b678d10cfd..88baff3c20 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -76,6 +76,11 @@ class Generation(Serializable): """May include things like reason for finishing (e.g. in OpenAI)""" # TODO: add log probs + @property + def lc_serializable(self) -> bool: + """This class is LangChain serializable.""" + return True + class BaseMessage(Serializable): """Message object.""" @@ -88,6 +93,11 @@ class BaseMessage(Serializable): def type(self) -> str: """Type of the message, used for serialization.""" + @property + def lc_serializable(self) -> bool: + """This class is LangChain serializable.""" + return True + class HumanMessage(BaseMessage): """Type of message that is spoken by the human.""" diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index ca3ffa2517..d61e50db68 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -1,13 +1,12 @@ """Unit tests for ReAct.""" -from typing import Any, List, Mapping, Optional, Union +from typing import Union from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool -from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.llms.base import LLM +from langchain.llms.fake import FakeListLLM from langchain.prompts.prompt import PromptTemplate from langchain.schema import AgentAction @@ -22,33 +21,6 @@ Made in 2022.""" _FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}") -class FakeListLLM(LLM): - """Fake LLM for testing that outputs elements of a list.""" - - responses: List[str] - i: int = -1 - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "fake_list" - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Increment counter, and then return response in that index.""" - self.i += 1 - return self.responses[self.i] - - @property - def _identifying_params(self) -> Mapping[str, Any]: - return {} - - class FakeDocstore(Docstore): """Fake docstore for testing purposes.""" diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 8d0f848799..72babf3a2a 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -1,17 +1,17 @@ """Test LLM callbacks.""" +from langchain.chat_models.fake import FakeListChatModel +from langchain.llms.fake import FakeListLLM from langchain.schema import HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import ( FakeCallbackHandler, FakeCallbackHandlerWithChatStart, ) -from tests.unit_tests.llms.fake_chat_model import FakeChatModel -from tests.unit_tests.llms.fake_llm import FakeLLM def test_llm_with_callbacks() -> None: """Test LLM callbacks.""" handler = FakeCallbackHandler() - llm = FakeLLM(callbacks=[handler], verbose=True) + llm = FakeListLLM(callbacks=[handler], verbose=True, responses=["foo"]) output = llm("foo") assert output == "foo" assert handler.starts == 1 @@ -22,7 +22,9 @@ def test_llm_with_callbacks() -> None: def test_chat_model_with_v1_callbacks() -> None: """Test chat model callbacks fall back to on_llm_start.""" handler = FakeCallbackHandler() - llm = FakeChatModel(callbacks=[handler], verbose=True) + llm = FakeListChatModel( + callbacks=[handler], verbose=True, responses=["fake response"] + ) output = llm([HumanMessage(content="foo")]) assert output.content == "fake response" assert handler.starts == 1 @@ -35,7 +37,9 @@ def test_chat_model_with_v1_callbacks() -> None: def test_chat_model_with_v2_callbacks() -> None: """Test chat model callbacks fall back to on_llm_start.""" handler = FakeCallbackHandlerWithChatStart() - llm = FakeChatModel(callbacks=[handler], verbose=True) + llm = FakeListChatModel( + callbacks=[handler], verbose=True, responses=["fake response"] + ) output = llm([HumanMessage(content="foo")]) assert output.content == "fake response" assert handler.starts == 1 diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py new file mode 100644 index 0000000000..0cbe324e1f --- /dev/null +++ b/tests/unit_tests/test_cache.py @@ -0,0 +1,146 @@ +"""Test caching for LLMs and ChatModels.""" +from typing import Dict, Generator, List, Union + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +import langchain +from langchain.cache import ( + InMemoryCache, + SQLAlchemyCache, +) +from langchain.chat_models import FakeListChatModel +from langchain.chat_models.base import BaseChatModel, dumps +from langchain.llms import FakeListLLM +from langchain.llms.base import BaseLLM +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + Generation, + HumanMessage, +) + + +def get_sqlite_cache() -> SQLAlchemyCache: + return SQLAlchemyCache(engine=create_engine("sqlite://")) + + +CACHE_OPTIONS = [ + InMemoryCache, + get_sqlite_cache, +] + + +@pytest.fixture(autouse=True, params=CACHE_OPTIONS) +def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]: + # Will be run before each test + cache_instance = request.param + langchain.llm_cache = cache_instance() + if langchain.llm_cache: + langchain.llm_cache.clear() + else: + raise ValueError("Cache not set. This should never happen.") + + yield + + # Will be run after each test + if langchain.llm_cache: + langchain.llm_cache.clear() + else: + raise ValueError("Cache not set. This should never happen.") + + +def test_llm_caching() -> None: + prompt = "How are you?" + response = "Test response" + cached_response = "Cached test response" + llm = FakeListLLM(responses=[response]) + if langchain.llm_cache: + langchain.llm_cache.update( + prompt=prompt, + llm_string=create_llm_string(llm), + return_val=[Generation(text=cached_response)], + ) + assert llm(prompt) == cached_response + else: + raise ValueError( + "The cache not set. This should never happen, as the pytest fixture " + "`set_cache_and_teardown` always sets the cache." + ) + + +def test_old_sqlite_llm_caching() -> None: + if isinstance(langchain.llm_cache, SQLAlchemyCache): + prompt = "How are you?" + response = "Test response" + cached_response = "Cached test response" + llm = FakeListLLM(responses=[response]) + items = [ + langchain.llm_cache.cache_schema( + prompt=prompt, + llm=create_llm_string(llm), + response=cached_response, + idx=0, + ) + ] + with Session(langchain.llm_cache.engine) as session, session.begin(): + for item in items: + session.merge(item) + assert llm(prompt) == cached_response + + +def test_chat_model_caching() -> None: + prompt: List[BaseMessage] = [HumanMessage(content="How are you?")] + response = "Test response" + cached_response = "Cached test response" + cached_message = AIMessage(content=cached_response) + llm = FakeListChatModel(responses=[response]) + if langchain.llm_cache: + langchain.llm_cache.update( + prompt=dumps(prompt), + llm_string=llm._get_llm_string(), + return_val=[ChatGeneration(message=cached_message)], + ) + result = llm(prompt) + assert isinstance(result, AIMessage) + assert result.content == cached_response + else: + raise ValueError( + "The cache not set. This should never happen, as the pytest fixture " + "`set_cache_and_teardown` always sets the cache." + ) + + +def test_chat_model_caching_params() -> None: + prompt: List[BaseMessage] = [HumanMessage(content="How are you?")] + response = "Test response" + cached_response = "Cached test response" + cached_message = AIMessage(content=cached_response) + llm = FakeListChatModel(responses=[response]) + if langchain.llm_cache: + langchain.llm_cache.update( + prompt=dumps(prompt), + llm_string=llm._get_llm_string(functions=[]), + return_val=[ChatGeneration(message=cached_message)], + ) + result = llm(prompt, functions=[]) + assert isinstance(result, AIMessage) + assert result.content == cached_response + result_no_params = llm(prompt) + assert isinstance(result_no_params, AIMessage) + assert result_no_params.content == response + + else: + raise ValueError( + "The cache not set. This should never happen, as the pytest fixture " + "`set_cache_and_teardown` always sets the cache." + ) + + +def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str: + _dict: Dict = llm.dict() + _dict["stop"] = None + return str(sorted([(k, v) for k, v in _dict.items()]))