Mirror of https://github.com/hwchase17/langchain (synced 2024-11-06 03:20:49 +00:00)
# Add caching to BaseChatModel

Fixes #1644

(Sidenote: While testing, I noticed we have multiple implementations of fake LLMs used for testing. I consolidated them.)

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

Models
- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord: RicChilligerDude#7589

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
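For reviewers, a minimal usage sketch of what this change enables, drawn from the docs added below (assumes an OpenAI API key is configured in the environment):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

# Enable the global LLM cache; chat model calls now consult it first.
langchain.llm_cache = InMemoryCache()

llm = ChatOpenAI()
llm.predict("Tell me a joke")  # first call hits the OpenAI API
llm.predict("Tell me a joke")  # repeated call is answered from the cache
```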
This commit is contained in:
parent c289cc891a
commit 068142fce2
@ -0,0 +1,9 @@
# Caching

LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:

It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.

It can speed up your application by reducing the number of API calls you make to the LLM provider.

import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"

<CachingChat/>
@ -0,0 +1,97 @@

```python
import langchain
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI()
```

## In Memory Cache

```python
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()

# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
Wall time: 4.83 s

"\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
```

</CodeOutputBlock>

```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 238 µs, sys: 143 µs, total: 381 µs
Wall time: 1.76 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>

## SQLite Cache

```bash
rm .langchain.db
```

```python
# We can do the same thing with a SQLite cache
from langchain.cache import SQLiteCache
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```

```python
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
Wall time: 825 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>

```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
Wall time: 2.67 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>
@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
 langchain.llm_cache = InMemoryCache()
 
 # The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
 ```
 
 <CodeOutputBlock lang="python">
@ -32,7 +32,7 @@ llm("Tell me a joke")
 
 ```python
 # The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
 ```
 
 <CodeOutputBlock lang="python">
@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
 
 ```python
 # The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
 ```
 
 <CodeOutputBlock lang="python">
@ -82,7 +82,7 @@ llm("Tell me a joke")
 
 ```python
 # The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
 ```
 
 <CodeOutputBlock lang="python">
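For context on the `llm(...)` → `llm.predict(...)` change above: with a chat model, `__call__` takes a list of messages, while `predict` is the string-in/string-out convenience method. A minimal sketch (assuming `ChatOpenAI` and a configured API key):

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

llm = ChatOpenAI()

llm.predict("Tell me a joke")                  # str in, str out
llm([HumanMessage(content="Tell me a joke")])  # messages in, AIMessage out
```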
@ -4,6 +4,7 @@ from __future__ import annotations
 import hashlib
 import inspect
 import json
+import logging
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@ -11,8 +12,8 @@ from typing import (
     Any,
     Callable,
     Dict,
-    List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@ -31,13 +32,17 @@ except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
 from langchain.embeddings.base import Embeddings
+from langchain.load.dump import dumps
+from langchain.load.load import loads
 from langchain.schema import Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore
 
+logger = logging.getLogger(__file__)
+
 if TYPE_CHECKING:
     import momento
 
-RETURN_VAL_TYPE = List[Generation]
+RETURN_VAL_TYPE = Sequence[Generation]
 
 
 def _hash(_input: str) -> str:
@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
         with Session(self.engine) as session:
             rows = session.execute(stmt).fetchall()
             if rows:
-                return [Generation(text=row[0]) for row in rows]
+                try:
+                    return [loads(row[0]) for row in rows]
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    return [Generation(text=row[0]) for row in rows]
         return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update based on prompt and llm_string."""
         items = [
-            self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
+            self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
             for i, gen in enumerate(return_val)
         ]
         with Session(self.engine) as session, session.begin():
@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
     def clear(self, **kwargs: Any) -> None:
         """Clear cache."""
         with Session(self.engine) as session:
-            session.execute(self.cache_schema.delete())
+            session.query(self.cache_schema).delete()
 
 
 class SQLiteCache(SQLAlchemyCache):
@ -209,6 +225,12 @@ class RedisCache(BaseCache):
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
+        for gen in return_val:
+            if not isinstance(return_val, Generation):
+                raise ValueError(
+                    "RedisCache only supports caching of normal LLM generations, "
+                    f"got {type(gen)}"
+                )
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
+        for gen in return_val:
+            if not isinstance(return_val, Generation):
+                raise ValueError(
+                    "RedisSemanticCache only supports caching of "
+                    f"normal LLM generations, got {type(gen)}"
+                )
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {
@ -426,6 +454,12 @@ class GPTCache(BaseCache):
         First, retrieve the corresponding cache object using the `llm_string` parameter,
         and then store the `prompt` and `return_val` in the cache object.
         """
+        for gen in return_val:
+            if not isinstance(return_val, Generation):
+                raise ValueError(
+                    "GPTCache only supports caching of normal LLM generations, "
+                    f"got {type(gen)}"
+                )
         from gptcache.adapter.api import put
 
         _gptcache = self._get_gptcache(llm_string)
@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
         """
         from momento.responses import CacheGet
 
-        generations = []
+        generations: RETURN_VAL_TYPE = []
 
         get_response = self.cache_client.get(
             self.cache_name, self.__key(prompt, llm_string)
@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
             SdkException: Momento service or network error
             Exception: Unexpected response
         """
+        for gen in return_val:
+            if not isinstance(return_val, Generation):
+                raise ValueError(
+                    "Momento only supports caching of normal LLM generations, "
+                    f"got {type(gen)}"
+                )
         key = self.__key(prompt, llm_string)
         value = _dump_generations_to_json(return_val)
         set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
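The `SQLAlchemyCache` change above stores each `Generation` serialized with `dumps()` and reads it back with `loads()`, falling back to the old plain-text rows for caches written before this PR. A minimal round-trip sketch of that serialization (illustrative only):

```python
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import Generation

gen = Generation(text="Hello")
serialized = dumps(gen)       # JSON string, as now stored in the cache table
restored = loads(serialized)  # reconstructed Generation object
print(restored.text)          # -> "Hello"
```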
@ -1,5 +1,6 @@
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
 __all__ = [
     "ChatOpenAI",
     "AzureChatOpenAI",
+    "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
     "ChatGooglePalm",
@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
-from langchain.load.dump import dumpd
+from langchain.load.dump import dumpd, dumps
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@ -35,6 +35,7 @@ def _get_verbosity() -> bool:
 
 
 class BaseChatModel(BaseLanguageModel, ABC):
+    cache: Optional[bool] = None
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return {}
 
+    def _get_invocation_params(
+        self,
+        stop: Optional[List[str]] = None,
+    ) -> dict:
+        params = self.dict()
+        params["stop"] = stop
+        return params
+
+    def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
+        if self.lc_serializable:
+            params = {**kwargs, **{"stop": stop}}
+            param_string = str(sorted([(k, v) for k, v in params.items()]))
+            llm_string = dumps(self)
+            return llm_string + "---" + param_string
+        else:
+            params = self._get_invocation_params(stop=stop)
+            params = {**params, **kwargs}
+            return str(sorted([(k, v) for k, v in params.items()]))
+
     def generate(
         self,
         messages: List[List[BaseMessage]],
@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
-        params = self.dict()
-        params["stop"] = stop
+        params = self._get_invocation_params(stop=stop)
         options = {"stop": stop}
 
         callback_manager = CallbackManager.configure(
@ -87,14 +105,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
             dumpd(self), messages, invocation_params=params, options=options
         )
 
-        new_arg_supported = inspect.signature(self._generate).parameters.get(
-            "run_manager"
-        )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
-                if new_arg_supported
-                else self._generate(m, stop=stop)
+                self._generate_with_cache(
+                    m, stop=stop, run_manager=run_manager, **kwargs
+                )
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
@ -118,8 +133,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
-        params = self.dict()
-        params["stop"] = stop
+        params = self._get_invocation_params(stop=stop)
         options = {"stop": stop}
 
         callback_manager = AsyncCallbackManager.configure(
@ -133,15 +147,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
             dumpd(self), messages, invocation_params=params, options=options
         )
 
-        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
-            "run_manager"
-        )
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
-                    if new_arg_supported
-                    else self._agenerate(m, stop=stop)
+                    self._agenerate_with_cache(
+                        m, stop=stop, run_manager=run_manager, **kwargs
+                    )
                     for m in messages
                 ]
             )
@ -178,6 +189,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
             prompt_messages, stop=stop, callbacks=callbacks, **kwargs
         )
 
+    def _generate_with_cache(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        new_arg_supported = inspect.signature(self._generate).parameters.get(
+            "run_manager"
+        )
+        disregard_cache = self.cache is not None and not self.cache
+        if langchain.llm_cache is None or disregard_cache:
+            # This happens when langchain.cache is None, but self.cache is True
+            if self.cache is not None and self.cache:
+                raise ValueError(
+                    "Asked to cache, but no cache found at `langchain.cache`."
+                )
+            if new_arg_supported:
+                return self._generate(
+                    messages, stop=stop, run_manager=run_manager, **kwargs
+                )
+            else:
+                return self._generate(messages, stop=stop, **kwargs)
+        else:
+            llm_string = self._get_llm_string(stop=stop, **kwargs)
+            prompt = dumps(messages)
+            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+            if isinstance(cache_val, list):
+                return ChatResult(generations=cache_val)
+            else:
+                if new_arg_supported:
+                    result = self._generate(
+                        messages, stop=stop, run_manager=run_manager, **kwargs
+                    )
+                else:
+                    result = self._generate(messages, stop=stop, **kwargs)
+                langchain.llm_cache.update(prompt, llm_string, result.generations)
+                return result
+
+    async def _agenerate_with_cache(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
+            "run_manager"
+        )
+        disregard_cache = self.cache is not None and not self.cache
+        if langchain.llm_cache is None or disregard_cache:
+            # This happens when langchain.cache is None, but self.cache is True
+            if self.cache is not None and self.cache:
+                raise ValueError(
+                    "Asked to cache, but no cache found at `langchain.cache`."
+                )
+            if new_arg_supported:
+                return await self._agenerate(
+                    messages, stop=stop, run_manager=run_manager, **kwargs
+                )
+            else:
+                return await self._agenerate(messages, stop=stop, **kwargs)
+        else:
+            llm_string = self._get_llm_string(stop=stop, **kwargs)
+            prompt = dumps(messages)
+            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+            if isinstance(cache_val, list):
+                return ChatResult(generations=cache_val)
+            else:
+                if new_arg_supported:
+                    result = await self._agenerate(
+                        messages, stop=stop, run_manager=run_manager, **kwargs
+                    )
+                else:
+                    result = await self._agenerate(messages, stop=stop, **kwargs)
+                langchain.llm_cache.update(prompt, llm_string, result.generations)
+                return result
+
     @abstractmethod
     def _generate(
         self,
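A minimal sketch of how the new `cache` field on `BaseChatModel` interacts with the global cache, based on the `_generate_with_cache` logic above (assumes `ChatOpenAI`):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = InMemoryCache()

cached_llm = ChatOpenAI()               # cache=None -> uses the global cache
uncached_llm = ChatOpenAI(cache=False)  # disregards the global cache entirely

# If cache=True while langchain.llm_cache is None, generate() raises:
# ValueError("Asked to cache, but no cache found at `langchain.cache`.")
```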
langchain/chat_models/fake.py (new file, 33 lines)
@ -0,0 +1,33 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    responses: List
    i: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-list-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """First try to lookup in queries, else return 'foo' or 'bar'."""
        response = self.responses[self.i]
        self.i += 1
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"responses": self.responses}
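A short sketch of how the new `FakeListChatModel` is used in the tests elsewhere in this PR (responses are returned in order, one per call):

```python
from langchain.chat_models import FakeListChatModel
from langchain.schema import HumanMessage

llm = FakeListChatModel(responses=["first", "second"])

llm([HumanMessage(content="hi")])  # -> AIMessage(content="first")
llm([HumanMessage(content="hi")])  # -> AIMessage(content="second")
```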
@ -76,6 +76,11 @@ class Generation(Serializable):
     """May include things like reason for finishing (e.g. in OpenAI)"""
     # TODO: add log probs
 
+    @property
+    def lc_serializable(self) -> bool:
+        """This class is LangChain serializable."""
+        return True
+
 
 class BaseMessage(Serializable):
     """Message object."""
@ -88,6 +93,11 @@ class BaseMessage(Serializable):
     def type(self) -> str:
         """Type of the message, used for serialization."""
 
+    @property
+    def lc_serializable(self) -> bool:
+        """This class is LangChain serializable."""
+        return True
+
 
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
@ -1,13 +1,12 @@
 """Unit tests for ReAct."""
 
-from typing import Any, List, Mapping, Optional, Union
+from typing import Union
 
 from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
 from langchain.agents.tools import Tool
-from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.fake import FakeListLLM
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import AgentAction
 
@ -22,33 +21,6 @@ Made in 2022."""
 _FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
 
 
-class FakeListLLM(LLM):
-    """Fake LLM for testing that outputs elements of a list."""
-
-    responses: List[str]
-    i: int = -1
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "fake_list"
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Increment counter, and then return response in that index."""
-        self.i += 1
-        return self.responses[self.i]
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {}
-
-
 class FakeDocstore(Docstore):
     """Fake docstore for testing purposes."""
 
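For reference, the consolidated `FakeListLLM` (now imported from `langchain.llms.fake`) behaves the same way as the class removed above; a short usage sketch:

```python
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["foo", "bar"])

llm("any prompt")      # -> "foo"
llm("another prompt")  # -> "bar"
```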
@ -1,17 +1,17 @@
 """Test LLM callbacks."""
+from langchain.chat_models.fake import FakeListChatModel
+from langchain.llms.fake import FakeListLLM
 from langchain.schema import HumanMessage
 from tests.unit_tests.callbacks.fake_callback_handler import (
     FakeCallbackHandler,
     FakeCallbackHandlerWithChatStart,
 )
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
 def test_llm_with_callbacks() -> None:
     """Test LLM callbacks."""
     handler = FakeCallbackHandler()
-    llm = FakeLLM(callbacks=[handler], verbose=True)
+    llm = FakeListLLM(callbacks=[handler], verbose=True, responses=["foo"])
     output = llm("foo")
     assert output == "foo"
     assert handler.starts == 1
@ -22,7 +22,9 @@ def test_llm_with_callbacks() -> None:
 def test_chat_model_with_v1_callbacks() -> None:
     """Test chat model callbacks fall back to on_llm_start."""
     handler = FakeCallbackHandler()
-    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    llm = FakeListChatModel(
+        callbacks=[handler], verbose=True, responses=["fake response"]
+    )
     output = llm([HumanMessage(content="foo")])
     assert output.content == "fake response"
     assert handler.starts == 1
@ -35,7 +37,9 @@ def test_chat_model_with_v1_callbacks() -> None:
 def test_chat_model_with_v2_callbacks() -> None:
     """Test chat model callbacks fall back to on_llm_start."""
     handler = FakeCallbackHandlerWithChatStart()
-    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    llm = FakeListChatModel(
+        callbacks=[handler], verbose=True, responses=["fake response"]
+    )
     output = llm([HumanMessage(content="foo")])
     assert output.content == "fake response"
     assert handler.starts == 1
tests/unit_tests/test_cache.py (new file, 146 lines)
@ -0,0 +1,146 @@
"""Test caching for LLMs and ChatModels."""
from typing import Dict, Generator, List, Union

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

import langchain
from langchain.cache import (
    InMemoryCache,
    SQLAlchemyCache,
)
from langchain.chat_models import FakeListChatModel
from langchain.chat_models.base import BaseChatModel, dumps
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    Generation,
    HumanMessage,
)


def get_sqlite_cache() -> SQLAlchemyCache:
    return SQLAlchemyCache(engine=create_engine("sqlite://"))


CACHE_OPTIONS = [
    InMemoryCache,
    get_sqlite_cache,
]


@pytest.fixture(autouse=True, params=CACHE_OPTIONS)
def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
    # Will be run before each test
    cache_instance = request.param
    langchain.llm_cache = cache_instance()
    if langchain.llm_cache:
        langchain.llm_cache.clear()
    else:
        raise ValueError("Cache not set. This should never happen.")

    yield

    # Will be run after each test
    if langchain.llm_cache:
        langchain.llm_cache.clear()
    else:
        raise ValueError("Cache not set. This should never happen.")


def test_llm_caching() -> None:
    prompt = "How are you?"
    response = "Test response"
    cached_response = "Cached test response"
    llm = FakeListLLM(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=prompt,
            llm_string=create_llm_string(llm),
            return_val=[Generation(text=cached_response)],
        )
        assert llm(prompt) == cached_response
    else:
        raise ValueError(
            "The cache not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def test_old_sqlite_llm_caching() -> None:
    if isinstance(langchain.llm_cache, SQLAlchemyCache):
        prompt = "How are you?"
        response = "Test response"
        cached_response = "Cached test response"
        llm = FakeListLLM(responses=[response])
        items = [
            langchain.llm_cache.cache_schema(
                prompt=prompt,
                llm=create_llm_string(llm),
                response=cached_response,
                idx=0,
            )
        ]
        with Session(langchain.llm_cache.engine) as session, session.begin():
            for item in items:
                session.merge(item)
        assert llm(prompt) == cached_response


def test_chat_model_caching() -> None:
    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
    response = "Test response"
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(),
            return_val=[ChatGeneration(message=cached_message)],
        )
        result = llm(prompt)
        assert isinstance(result, AIMessage)
        assert result.content == cached_response
    else:
        raise ValueError(
            "The cache not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def test_chat_model_caching_params() -> None:
    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
    response = "Test response"
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(functions=[]),
            return_val=[ChatGeneration(message=cached_message)],
        )
        result = llm(prompt, functions=[])
        assert isinstance(result, AIMessage)
        assert result.content == cached_response
        result_no_params = llm(prompt)
        assert isinstance(result_no_params, AIMessage)
        assert result_no_params.content == response

    else:
        raise ValueError(
            "The cache not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
    _dict: Dict = llm.dict()
    _dict["stop"] = None
    return str(sorted([(k, v) for k, v in _dict.items()]))