Add caching to BaseChatModel (issue #1644) (#5089)

#  Add caching to BaseChatModel
Fixes #1644

(Sidenote: While testing, I noticed we have multiple implementations of
Fake LLMs, used for testing. I consolidated them.)

## Who can review?
Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
Models
- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord:
RicChilligerDude#7589

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
UmerHA 2023-06-24 20:45:09 +02:00 committed by GitHub
parent c289cc891a
commit 068142fce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 465 additions and 63 deletions

View File

@ -0,0 +1,9 @@
# Caching
LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:
It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.
It can speed up your application by reducing the number of API calls you make to the LLM provider.
import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"
<CachingChat/>

View File

@ -0,0 +1,97 @@
```python
import langchain
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI()
```
## In Memory Cache
```python
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
Wall time: 4.83 s
"\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
```
</CodeOutputBlock>
```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 238 µs, sys: 143 µs, total: 381 µs
Wall time: 1.76 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>
## SQLite Cache
```bash
rm .langchain.db
```
```python
# We can do the same thing with a SQLite cache
from langchain.cache import SQLiteCache
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```
```python
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
Wall time: 825 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>
```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
Wall time: 2.67 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>

View File

@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
# The first time, it is not yet in cache, so it should take longer
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@ -32,7 +32,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```python
# The first time, it is not yet in cache, so it should take longer
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@ -82,7 +82,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib
import inspect
import json
import logging
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import (
@ -11,8 +12,8 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
@ -31,13 +32,17 @@ except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__)
if TYPE_CHECKING:
import momento
RETURN_VAL_TYPE = List[Generation]
RETURN_VAL_TYPE = Sequence[Generation]
def _hash(_input: str) -> str:
@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
with Session(self.engine) as session:
rows = session.execute(stmt).fetchall()
if rows:
return [Generation(text=row[0]) for row in rows]
try:
return [loads(row[0]) for row in rows]
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
return [Generation(text=row[0]) for row in rows]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update based on prompt and llm_string."""
items = [
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
for i, gen in enumerate(return_val)
]
with Session(self.engine) as session, session.begin():
@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
with Session(self.engine) as session:
session.execute(self.cache_schema.delete())
session.query(self.cache_schema).delete()
class SQLiteCache(SQLAlchemyCache):
@ -209,6 +225,12 @@ class RedisCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(return_val, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
# Write to a Redis HASH
key = self._key(prompt, llm_string)
self.redis.hset(
@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(return_val, Generation):
raise ValueError(
"RedisSemanticCache only supports caching of "
f"normal LLM generations, got {type(gen)}"
)
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
@ -426,6 +454,12 @@ class GPTCache(BaseCache):
First, retrieve the corresponding cache object using the `llm_string` parameter,
and then store the `prompt` and `return_val` in the cache object.
"""
for gen in return_val:
if not isinstance(return_val, Generation):
raise ValueError(
"GPTCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
from gptcache.adapter.api import put
_gptcache = self._get_gptcache(llm_string)
@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
"""
from momento.responses import CacheGet
generations = []
generations: RETURN_VAL_TYPE = []
get_response = self.cache_client.get(
self.cache_name, self.__key(prompt, llm_string)
@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
SdkException: Momento service or network error
Exception: Unexpected response
"""
for gen in return_val:
if not isinstance(return_val, Generation):
raise ValueError(
"Momento only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
key = self.__key(prompt, llm_string)
value = _dump_generations_to_json(return_val)
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)

View File

@ -1,5 +1,6 @@
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",
"ChatGooglePalm",

View File

@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.dump import dumpd, dumps
from langchain.schema import (
AIMessage,
BaseMessage,
@ -35,6 +35,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
) -> dict:
params = self.dict()
params["stop"] = stop
return params
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.lc_serializable:
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
def generate(
self,
messages: List[List[BaseMessage]],
@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self.dict()
params["stop"] = stop
params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
@ -87,14 +105,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._generate(m, stop=stop)
self._generate_with_cache(
m, stop=stop, run_manager=run_manager, **kwargs
)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
@ -118,8 +133,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self.dict()
params["stop"] = stop
params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
@ -133,15 +147,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._agenerate(m, stop=stop)
self._agenerate_with_cache(
m, stop=stop, run_manager=run_manager, **kwargs
)
for m in messages
]
)
@ -178,6 +189,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
self,

View File

@ -0,0 +1,33 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
responses: List
i: int = 0
@property
def _llm_type(self) -> str:
return "fake-list-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
response = self.responses[self.i]
self.i += 1
return response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"responses": self.responses}

View File

@ -76,6 +76,11 @@ class Generation(Serializable):
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
@property
def lc_serializable(self) -> bool:
"""This class is LangChain serializable."""
return True
class BaseMessage(Serializable):
"""Message object."""
@ -88,6 +93,11 @@ class BaseMessage(Serializable):
def type(self) -> str:
"""Type of the message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""This class is LangChain serializable."""
return True
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""

View File

@ -1,13 +1,12 @@
"""Unit tests for ReAct."""
from typing import Any, List, Mapping, Optional, Union
from typing import Union
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.llms.fake import FakeListLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction
@ -22,33 +21,6 @@ Made in 2022."""
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
class FakeListLLM(LLM):
"""Fake LLM for testing that outputs elements of a list."""
responses: List[str]
i: int = -1
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake_list"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1
return self.responses[self.i]
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}
class FakeDocstore(Docstore):
"""Fake docstore for testing purposes."""

View File

@ -1,17 +1,17 @@
"""Test LLM callbacks."""
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.schema import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeCallbackHandler,
FakeCallbackHandlerWithChatStart,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callbacks=[handler], verbose=True)
llm = FakeListLLM(callbacks=[handler], verbose=True, responses=["foo"])
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
@ -22,7 +22,9 @@ def test_llm_with_callbacks() -> None:
def test_chat_model_with_v1_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandler()
llm = FakeChatModel(callbacks=[handler], verbose=True)
llm = FakeListChatModel(
callbacks=[handler], verbose=True, responses=["fake response"]
)
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1
@ -35,7 +37,9 @@ def test_chat_model_with_v1_callbacks() -> None:
def test_chat_model_with_v2_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandlerWithChatStart()
llm = FakeChatModel(callbacks=[handler], verbose=True)
llm = FakeListChatModel(
callbacks=[handler], verbose=True, responses=["fake response"]
)
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1

View File

@ -0,0 +1,146 @@
"""Test caching for LLMs and ChatModels."""
from typing import Dict, Generator, List, Union
import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
import langchain
from langchain.cache import (
InMemoryCache,
SQLAlchemyCache,
)
from langchain.chat_models import FakeListChatModel
from langchain.chat_models.base import BaseChatModel, dumps
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
Generation,
HumanMessage,
)
def get_sqlite_cache() -> SQLAlchemyCache:
return SQLAlchemyCache(engine=create_engine("sqlite://"))
CACHE_OPTIONS = [
InMemoryCache,
get_sqlite_cache,
]
@pytest.fixture(autouse=True, params=CACHE_OPTIONS)
def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
# Will be run before each test
cache_instance = request.param
langchain.llm_cache = cache_instance()
if langchain.llm_cache:
langchain.llm_cache.clear()
else:
raise ValueError("Cache not set. This should never happen.")
yield
# Will be run after each test
if langchain.llm_cache:
langchain.llm_cache.clear()
else:
raise ValueError("Cache not set. This should never happen.")
def test_llm_caching() -> None:
prompt = "How are you?"
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
assert llm(prompt) == cached_response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
"`set_cache_and_teardown` always sets the cache."
)
def test_old_sqlite_llm_caching() -> None:
if isinstance(langchain.llm_cache, SQLAlchemyCache):
prompt = "How are you?"
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
items = [
langchain.llm_cache.cache_schema(
prompt=prompt,
llm=create_llm_string(llm),
response=cached_response,
idx=0,
)
]
with Session(langchain.llm_cache.engine) as session, session.begin():
for item in items:
session.merge(item)
assert llm(prompt) == cached_response
def test_chat_model_caching() -> None:
prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
response = "Test response"
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(),
return_val=[ChatGeneration(message=cached_message)],
)
result = llm(prompt)
assert isinstance(result, AIMessage)
assert result.content == cached_response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
"`set_cache_and_teardown` always sets the cache."
)
def test_chat_model_caching_params() -> None:
prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
response = "Test response"
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(functions=[]),
return_val=[ChatGeneration(message=cached_message)],
)
result = llm(prompt, functions=[])
assert isinstance(result, AIMessage)
assert result.content == cached_response
result_no_params = llm(prompt)
assert isinstance(result_no_params, AIMessage)
assert result_no_params.content == response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
"`set_cache_and_teardown` always sets the cache."
)
def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
_dict: Dict = llm.dict()
_dict["stop"] = None
return str(sorted([(k, v) for k, v in _dict.items()]))