diff --git a/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx b/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx
new file mode 100644
index 0000000000..c34cb22326
--- /dev/null
+++ b/docs/docs_skeleton/docs/modules/model_io/models/chat/how_to/chat_model_caching.mdx
@@ -0,0 +1,9 @@
+# Caching
+LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:
+
+- It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.
+- It can speed up your application by reducing the number of API calls you make to the LLM provider.
+
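+As a minimal sketch (assuming an OpenAI API key is configured), enabling the cache looks like this; the imported snippet below walks through the same flow with timings:
+
+```python
+import langchain
+from langchain.cache import InMemoryCache
+from langchain.chat_models import ChatOpenAI
+
+# Point the global LLM cache at an in-memory store
+langchain.llm_cache = InMemoryCache()
+
+llm = ChatOpenAI()
+llm.predict("Tell me a joke")  # first call hits the API and populates the cache
+llm.predict("Tell me a joke")  # identical second call is served from the cache
+```
+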
+import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"
+
+<CachingChat/>
diff --git a/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx b/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx
new file mode 100644
index 0000000000..580ca37488
--- /dev/null
+++ b/docs/snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx
@@ -0,0 +1,97 @@
+```python
+import langchain
+from langchain.chat_models import ChatOpenAI
+
+llm = ChatOpenAI()
+```
+
+## In Memory Cache
+
+
+```python
+from langchain.cache import InMemoryCache
+langchain.llm_cache = InMemoryCache()
+
+# The first time, it is not yet in cache, so it should take longer
+llm.predict("Tell me a joke")
+```
+
+
+
+```
+ CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
+ Wall time: 4.83 s
+
+
+ "\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
+```
+
+
+
+
+```python
+# The second time it is, so it goes faster
+llm.predict("Tell me a joke")
+```
+
+
+
+```
+ CPU times: user 238 µs, sys: 143 µs, total: 381 µs
+ Wall time: 1.76 ms
+
+
+    "\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
+```
+
+
+
+## SQLite Cache
+
+The SQLite cache stores results in a local database file (`.langchain.db` below), so cached completions persist across sessions; the in-memory cache lives only for the current process.
+
+```bash
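+# Delete any existing cache database so the first call below is a cache miss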
+rm .langchain.db
+```
+
+
+```python
+# We can do the same thing with a SQLite cache
+from langchain.cache import SQLiteCache
+langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
+```
+
+
+```python
+# The first time, it is not yet in cache, so it should take longer
+llm.predict("Tell me a joke")
+```
+
+
+
+```
+ CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
+ Wall time: 825 ms
+
+
+ '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
+```
+
+
+
+
+```python
+# The second time it is, so it goes faster
+llm.predict("Tell me a joke")
+```
+
+
+
+```
+ CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
+ Wall time: 2.67 ms
+
+
+ '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
+```
+
+
diff --git a/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx b/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx
index daa43c560f..5bb436ff82 100644
--- a/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx
+++ b/docs/snippets/modules/model_io/models/llms/how_to/llm_caching.mdx
@@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
# The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```
@@ -32,7 +32,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```
@@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```python
# The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```
@@ -82,7 +82,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```
diff --git a/langchain/cache.py b/langchain/cache.py
index 857e3dc5a6..2cfb7ff372 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib
import inspect
import json
+import logging
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import (
@@ -11,8 +12,8 @@ from typing import (
Any,
Callable,
Dict,
- List,
Optional,
+ Sequence,
Tuple,
Type,
Union,
@@ -31,13 +32,17 @@ except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
+from langchain.load.dump import dumps
+from langchain.load.load import loads
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
+logger = logging.getLogger(__file__)
+
if TYPE_CHECKING:
import momento
-RETURN_VAL_TYPE = List[Generation]
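+# Sequence (covariant) rather than List (invariant), so a List[ChatGeneration] is also a valid cache value.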
+RETURN_VAL_TYPE = Sequence[Generation]
def _hash(_input: str) -> str:
@@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
with Session(self.engine) as session:
rows = session.execute(stmt).fetchall()
if rows:
- return [Generation(text=row[0]) for row in rows]
+ try:
+ return [loads(row[0]) for row in rows]
+ except Exception:
+ logger.warning(
+ "Retrieving a cache value that could not be deserialized "
+ "properly. This is likely due to the cache being in an "
+ "older format. Please recreate your cache to avoid this "
+ "error."
+ )
+ # In a previous life we stored the raw text directly
+ # in the table, so assume it's in that format.
+ return [Generation(text=row[0]) for row in rows]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update based on prompt and llm_string."""
items = [
- self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
+ self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
for i, gen in enumerate(return_val)
]
with Session(self.engine) as session, session.begin():
@@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
with Session(self.engine) as session:
- session.execute(self.cache_schema.delete())
+ session.query(self.cache_schema).delete()
class SQLiteCache(SQLAlchemyCache):
@@ -209,6 +225,12 @@ class RedisCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
+ for gen in return_val:
+            if not isinstance(gen, Generation):
+ raise ValueError(
+ "RedisCache only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
# Write to a Redis HASH
key = self._key(prompt, llm_string)
self.redis.hset(
@@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
+ for gen in return_val:
+            if not isinstance(gen, Generation):
+ raise ValueError(
+ "RedisSemanticCache only supports caching of "
+ f"normal LLM generations, got {type(gen)}"
+ )
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
@@ -426,6 +454,12 @@ class GPTCache(BaseCache):
First, retrieve the corresponding cache object using the `llm_string` parameter,
and then store the `prompt` and `return_val` in the cache object.
"""
+ for gen in return_val:
+            if not isinstance(gen, Generation):
+ raise ValueError(
+ "GPTCache only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
from gptcache.adapter.api import put
_gptcache = self._get_gptcache(llm_string)
@@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
"""
from momento.responses import CacheGet
- generations = []
+ generations: RETURN_VAL_TYPE = []
get_response = self.cache_client.get(
self.cache_name, self.__key(prompt, llm_string)
@@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
SdkException: Momento service or network error
Exception: Unexpected response
"""
+ for gen in return_val:
+            if not isinstance(gen, Generation):
+ raise ValueError(
+ "Momento only supports caching of normal LLM generations, "
+ f"got {type(gen)}"
+ )
key = self.__key(prompt, llm_string)
value = _dump_generations_to_json(return_val)
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py
index 0c048ec3f3..39952027df 100644
--- a/langchain/chat_models/__init__.py
+++ b/langchain/chat_models/__init__.py
@@ -1,5 +1,6 @@
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
@@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
+ "FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",
"ChatGooglePalm",
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 5fa949ef0a..c4bb81305a 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
-from langchain.load.dump import dumpd
+from langchain.load.dump import dumpd, dumps
from langchain.schema import (
AIMessage,
BaseMessage,
@@ -35,6 +35,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
+ cache: Optional[bool] = None
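+    """Whether to cache responses. If None, defer to the global langchain.llm_cache; if False, never cache; if True, require the global cache to be set."""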
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
@@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
+ def _get_invocation_params(
+ self,
+ stop: Optional[List[str]] = None,
+ ) -> dict:
+ params = self.dict()
+ params["stop"] = stop
+ return params
+
+ def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
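+        # The llm_string is the model half of the cache key: for serializable models,
+        # the serialized model plus the sorted call params; otherwise just the sorted params.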
+ if self.lc_serializable:
+ params = {**kwargs, **{"stop": stop}}
+ param_string = str(sorted([(k, v) for k, v in params.items()]))
+ llm_string = dumps(self)
+ return llm_string + "---" + param_string
+ else:
+ params = self._get_invocation_params(stop=stop)
+ params = {**params, **kwargs}
+ return str(sorted([(k, v) for k, v in params.items()]))
+
def generate(
self,
messages: List[List[BaseMessage]],
@@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
-
- params = self.dict()
- params["stop"] = stop
+ params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
@@ -87,14 +105,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
dumpd(self), messages, invocation_params=params, options=options
)
- new_arg_supported = inspect.signature(self._generate).parameters.get(
- "run_manager"
- )
try:
results = [
- self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
- if new_arg_supported
- else self._generate(m, stop=stop)
+ self._generate_with_cache(
+ m, stop=stop, run_manager=run_manager, **kwargs
+ )
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
@@ -118,8 +133,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
- params = self.dict()
- params["stop"] = stop
+ params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
@@ -133,15 +147,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
dumpd(self), messages, invocation_params=params, options=options
)
- new_arg_supported = inspect.signature(self._agenerate).parameters.get(
- "run_manager"
- )
try:
results = await asyncio.gather(
*[
- self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
- if new_arg_supported
- else self._agenerate(m, stop=stop)
+ self._agenerate_with_cache(
+ m, stop=stop, run_manager=run_manager, **kwargs
+ )
for m in messages
]
)
@@ -178,6 +189,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
+ def _generate_with_cache(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ new_arg_supported = inspect.signature(self._generate).parameters.get(
+ "run_manager"
+ )
+ disregard_cache = self.cache is not None and not self.cache
+ if langchain.llm_cache is None or disregard_cache:
+            # This happens when langchain.llm_cache is None, but self.cache is True
+ if self.cache is not None and self.cache:
+ raise ValueError(
+ "Asked to cache, but no cache found at `langchain.cache`."
+ )
+ if new_arg_supported:
+ return self._generate(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ else:
+ return self._generate(messages, stop=stop, **kwargs)
+ else:
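+            # Caching is enabled: the lookup is keyed on the serialized messages plus the llm_string.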
+ llm_string = self._get_llm_string(stop=stop, **kwargs)
+ prompt = dumps(messages)
+ cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+ if isinstance(cache_val, list):
+ return ChatResult(generations=cache_val)
+ else:
+ if new_arg_supported:
+ result = self._generate(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ else:
+ result = self._generate(messages, stop=stop, **kwargs)
+ langchain.llm_cache.update(prompt, llm_string, result.generations)
+ return result
+
+ async def _agenerate_with_cache(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ new_arg_supported = inspect.signature(self._agenerate).parameters.get(
+ "run_manager"
+ )
+ disregard_cache = self.cache is not None and not self.cache
+ if langchain.llm_cache is None or disregard_cache:
+            # This happens when langchain.llm_cache is None, but self.cache is True
+ if self.cache is not None and self.cache:
+ raise ValueError(
+ "Asked to cache, but no cache found at `langchain.cache`."
+ )
+ if new_arg_supported:
+ return await self._agenerate(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ else:
+ return await self._agenerate(messages, stop=stop, **kwargs)
+ else:
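+            # Caching is enabled: the lookup is keyed on the serialized messages plus the llm_string.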
+ llm_string = self._get_llm_string(stop=stop, **kwargs)
+ prompt = dumps(messages)
+ cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+ if isinstance(cache_val, list):
+ return ChatResult(generations=cache_val)
+ else:
+ if new_arg_supported:
+ result = await self._agenerate(
+ messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ else:
+ result = await self._agenerate(messages, stop=stop, **kwargs)
+ langchain.llm_cache.update(prompt, llm_string, result.generations)
+ return result
+
@abstractmethod
def _generate(
self,
diff --git a/langchain/chat_models/fake.py b/langchain/chat_models/fake.py
new file mode 100644
index 0000000000..0149b1ce72
--- /dev/null
+++ b/langchain/chat_models/fake.py
@@ -0,0 +1,33 @@
+"""Fake ChatModel for testing purposes."""
+from typing import Any, List, Mapping, Optional
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import SimpleChatModel
+from langchain.schema import BaseMessage
+
+
+class FakeListChatModel(SimpleChatModel):
+ """Fake ChatModel for testing purposes."""
+
+ responses: List
+ i: int = 0
+
+ @property
+ def _llm_type(self) -> str:
+ return "fake-list-chat-model"
+
+ def _call(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """First try to lookup in queries, else return 'foo' or 'bar'."""
+ response = self.responses[self.i]
+ self.i += 1
+ return response
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ return {"responses": self.responses}
diff --git a/langchain/schema.py b/langchain/schema.py
index b678d10cfd..88baff3c20 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -76,6 +76,11 @@ class Generation(Serializable):
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
+ @property
+ def lc_serializable(self) -> bool:
+ """This class is LangChain serializable."""
+ return True
+
class BaseMessage(Serializable):
"""Message object."""
@@ -88,6 +93,11 @@ class BaseMessage(Serializable):
def type(self) -> str:
"""Type of the message, used for serialization."""
+ @property
+ def lc_serializable(self) -> bool:
+ """This class is LangChain serializable."""
+ return True
+
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py
index ca3ffa2517..d61e50db68 100644
--- a/tests/unit_tests/agents/test_react.py
+++ b/tests/unit_tests/agents/test_react.py
@@ -1,13 +1,12 @@
"""Unit tests for ReAct."""
-from typing import Any, List, Mapping, Optional, Union
+from typing import Union
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool
-from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.fake import FakeListLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction
@@ -22,33 +21,6 @@ Made in 2022."""
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
-class FakeListLLM(LLM):
- """Fake LLM for testing that outputs elements of a list."""
-
- responses: List[str]
- i: int = -1
-
- @property
- def _llm_type(self) -> str:
- """Return type of llm."""
- return "fake_list"
-
- def _call(
- self,
- prompt: str,
- stop: Optional[List[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- """Increment counter, and then return response in that index."""
- self.i += 1
- return self.responses[self.i]
-
- @property
- def _identifying_params(self) -> Mapping[str, Any]:
- return {}
-
-
class FakeDocstore(Docstore):
"""Fake docstore for testing purposes."""
diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py
index 8d0f848799..72babf3a2a 100644
--- a/tests/unit_tests/llms/test_callbacks.py
+++ b/tests/unit_tests/llms/test_callbacks.py
@@ -1,17 +1,17 @@
"""Test LLM callbacks."""
+from langchain.chat_models.fake import FakeListChatModel
+from langchain.llms.fake import FakeListLLM
from langchain.schema import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeCallbackHandler,
FakeCallbackHandlerWithChatStart,
)
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
- llm = FakeLLM(callbacks=[handler], verbose=True)
+ llm = FakeListLLM(callbacks=[handler], verbose=True, responses=["foo"])
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
@@ -22,7 +22,9 @@ def test_llm_with_callbacks() -> None:
def test_chat_model_with_v1_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandler()
- llm = FakeChatModel(callbacks=[handler], verbose=True)
+ llm = FakeListChatModel(
+ callbacks=[handler], verbose=True, responses=["fake response"]
+ )
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1
@@ -35,7 +37,9 @@ def test_chat_model_with_v1_callbacks() -> None:
def test_chat_model_with_v2_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandlerWithChatStart()
- llm = FakeChatModel(callbacks=[handler], verbose=True)
+ llm = FakeListChatModel(
+ callbacks=[handler], verbose=True, responses=["fake response"]
+ )
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1
diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py
new file mode 100644
index 0000000000..0cbe324e1f
--- /dev/null
+++ b/tests/unit_tests/test_cache.py
@@ -0,0 +1,146 @@
+"""Test caching for LLMs and ChatModels."""
+from typing import Dict, Generator, List, Union
+
+import pytest
+from _pytest.fixtures import FixtureRequest
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session
+
+import langchain
+from langchain.cache import (
+ InMemoryCache,
+ SQLAlchemyCache,
+)
+from langchain.chat_models import FakeListChatModel
+from langchain.chat_models.base import BaseChatModel, dumps
+from langchain.llms import FakeListLLM
+from langchain.llms.base import BaseLLM
+from langchain.schema import (
+ AIMessage,
+ BaseMessage,
+ ChatGeneration,
+ Generation,
+ HumanMessage,
+)
+
+
+def get_sqlite_cache() -> SQLAlchemyCache:
+ return SQLAlchemyCache(engine=create_engine("sqlite://"))
+
+
+CACHE_OPTIONS = [
+ InMemoryCache,
+ get_sqlite_cache,
+]
+
+
+@pytest.fixture(autouse=True, params=CACHE_OPTIONS)
+def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
+ # Will be run before each test
+ cache_instance = request.param
+ langchain.llm_cache = cache_instance()
+ if langchain.llm_cache:
+ langchain.llm_cache.clear()
+ else:
+ raise ValueError("Cache not set. This should never happen.")
+
+ yield
+
+ # Will be run after each test
+ if langchain.llm_cache:
+ langchain.llm_cache.clear()
+ else:
+ raise ValueError("Cache not set. This should never happen.")
+
+
+def test_llm_caching() -> None:
+ prompt = "How are you?"
+ response = "Test response"
+ cached_response = "Cached test response"
+ llm = FakeListLLM(responses=[response])
+ if langchain.llm_cache:
+ langchain.llm_cache.update(
+ prompt=prompt,
+ llm_string=create_llm_string(llm),
+ return_val=[Generation(text=cached_response)],
+ )
+ assert llm(prompt) == cached_response
+ else:
+ raise ValueError(
+ "The cache not set. This should never happen, as the pytest fixture "
+ "`set_cache_and_teardown` always sets the cache."
+ )
+
+
+def test_old_sqlite_llm_caching() -> None:
+ if isinstance(langchain.llm_cache, SQLAlchemyCache):
+ prompt = "How are you?"
+ response = "Test response"
+ cached_response = "Cached test response"
+ llm = FakeListLLM(responses=[response])
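+        # Write a row in the pre-serialization format (raw response text) to exercise
+        # the deserialization fallback in SQLAlchemyCache.lookup.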
+ items = [
+ langchain.llm_cache.cache_schema(
+ prompt=prompt,
+ llm=create_llm_string(llm),
+ response=cached_response,
+ idx=0,
+ )
+ ]
+ with Session(langchain.llm_cache.engine) as session, session.begin():
+ for item in items:
+ session.merge(item)
+ assert llm(prompt) == cached_response
+
+
+def test_chat_model_caching() -> None:
+ prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
+ response = "Test response"
+ cached_response = "Cached test response"
+ cached_message = AIMessage(content=cached_response)
+ llm = FakeListChatModel(responses=[response])
+ if langchain.llm_cache:
+ langchain.llm_cache.update(
+ prompt=dumps(prompt),
+ llm_string=llm._get_llm_string(),
+ return_val=[ChatGeneration(message=cached_message)],
+ )
+ result = llm(prompt)
+ assert isinstance(result, AIMessage)
+ assert result.content == cached_response
+ else:
+ raise ValueError(
+ "The cache not set. This should never happen, as the pytest fixture "
+ "`set_cache_and_teardown` always sets the cache."
+ )
+
+
+def test_chat_model_caching_params() -> None:
+ prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
+ response = "Test response"
+ cached_response = "Cached test response"
+ cached_message = AIMessage(content=cached_response)
+ llm = FakeListChatModel(responses=[response])
+ if langchain.llm_cache:
+ langchain.llm_cache.update(
+ prompt=dumps(prompt),
+ llm_string=llm._get_llm_string(functions=[]),
+ return_val=[ChatGeneration(message=cached_message)],
+ )
+ result = llm(prompt, functions=[])
+ assert isinstance(result, AIMessage)
+ assert result.content == cached_response
+ result_no_params = llm(prompt)
+ assert isinstance(result_no_params, AIMessage)
+ assert result_no_params.content == response
+
+ else:
+ raise ValueError(
+ "The cache not set. This should never happen, as the pytest fixture "
+ "`set_cache_and_teardown` always sets the cache."
+ )
+
+
+def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
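+    """Build the llm_string the way a non-serializable model does: its dict plus stop, sorted."""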
+ _dict: Dict = llm.dict()
+ _dict["stop"] = None
+ return str(sorted([(k, v) for k, v in _dict.items()]))