Mirror of https://github.com/hwchase17/langchain (synced 2024-11-06 03:20:49 +00:00)
# Add caching to BaseChatModel

Fixes #1644

(Sidenote: While testing, I noticed we have multiple implementations of Fake LLMs, used for testing. I consolidated them.)

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

Models
- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord: RicChilligerDude#7589

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in: parent c289cc891a, commit 068142fce2.
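For orientation, here is a minimal sketch of what this change enables, mirroring the documentation added below. It is not part of the diff itself, and it assumes an OpenAI API key is configured in the environment.

```python
# Minimal usage sketch of chat model caching (assumes OPENAI_API_KEY is set).
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

# One global cache is shared by LLMs and chat models.
langchain.llm_cache = InMemoryCache()

llm = ChatOpenAI()
llm.predict("Tell me a joke")  # first call hits the API and fills the cache
llm.predict("Tell me a joke")  # repeated call is answered from the cache
```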
@@ -0,0 +1,9 @@
# Caching
LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:

It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.
It can speed up your application by reducing the number of API calls you make to the LLM provider.

import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"

<CachingChat/>
@@ -0,0 +1,97 @@
```python
import langchain
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI()
```

## In Memory Cache

```python
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()

# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
Wall time: 4.83 s

"\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
```

</CodeOutputBlock>

```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 238 µs, sys: 143 µs, total: 381 µs
Wall time: 1.76 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>

## SQLite Cache

```bash
rm .langchain.db
```

```python
# We can do the same thing with a SQLite cache
from langchain.cache import SQLiteCache
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```

```python
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
Wall time: 825 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>

```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">

```
CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
Wall time: 2.67 ms

'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```

</CodeOutputBlock>
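The timings shown above come from notebook-style timing magics that are not part of the snippet itself. A rough, non-authoritative way to reproduce the comparison in a plain script (assuming `ChatOpenAI` and an API key, as in the docs) is:

```python
# Sketch: time an uncached call vs. a cached repeat outside a notebook.
import time

import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = InMemoryCache()
llm = ChatOpenAI()

for label in ("first (uncached)", "second (cached)"):
    start = time.perf_counter()
    llm.predict("Tell me a joke")
    print(f"{label}: {time.perf_counter() - start:.3f} s")
```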
@@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()

# The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">
@@ -32,7 +32,7 @@ llm("Tell me a joke")

```python
# The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">
@@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")

```python
# The first time, it is not yet in cache, so it should take longer
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">
@@ -82,7 +82,7 @@ llm("Tell me a joke")

```python
# The second time it is, so it goes faster
-llm("Tell me a joke")
+llm.predict("Tell me a joke")
```

<CodeOutputBlock lang="python">
@@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib
import inspect
import json
import logging
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import (
@@ -11,8 +12,8 @@ from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
@@ -31,13 +32,17 @@ except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore

logger = logging.getLogger(__file__)

if TYPE_CHECKING:
    import momento

-RETURN_VAL_TYPE = List[Generation]
+RETURN_VAL_TYPE = Sequence[Generation]


def _hash(_input: str) -> str:
@@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
        with Session(self.engine) as session:
            rows = session.execute(stmt).fetchall()
            if rows:
-               return [Generation(text=row[0]) for row in rows]
+               try:
+                   return [loads(row[0]) for row in rows]
+               except Exception:
+                   logger.warning(
+                       "Retrieving a cache value that could not be deserialized "
+                       "properly. This is likely due to the cache being in an "
+                       "older format. Please recreate your cache to avoid this "
+                       "error."
+                   )
+                   # In a previous life we stored the raw text directly
+                   # in the table, so assume it's in that format.
+                   return [Generation(text=row[0]) for row in rows]
            return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string."""
        items = [
-           self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
+           self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
            for i, gen in enumerate(return_val)
        ]
        with Session(self.engine) as session, session.begin():
@@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        with Session(self.engine) as session:
-           session.execute(self.cache_schema.delete())
+           session.query(self.cache_schema).delete()
@@ -209,6 +225,12 @@ class RedisCache(BaseCache):

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        # Write to a Redis HASH
        key = self._key(prompt, llm_string)
        self.redis.hset(
@@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        # Write to vectorstore
        metadata = {
@@ -426,6 +454,12 @@ class GPTCache(BaseCache):
        First, retrieve the corresponding cache object using the `llm_string` parameter,
        and then store the `prompt` and `return_val` in the cache object.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "GPTCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        from gptcache.adapter.api import put

        _gptcache = self._get_gptcache(llm_string)
@@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
        """
        from momento.responses import CacheGet

-       generations = []
+       generations: RETURN_VAL_TYPE = []

        get_response = self.cache_client.get(
            self.cache_name, self.__key(prompt, llm_string)
@@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
            SdkException: Momento service or network error
            Exception: Unexpected response
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "Momento only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        key = self.__key(prompt, llm_string)
        value = _dump_generations_to_json(return_val)
        set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
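All of the cache backends touched above implement the same small interface: `lookup` and `update`, keyed on `(prompt, llm_string)`, plus `clear`. A dict-backed sketch of that interface follows (essentially what `InMemoryCache` already does; the class name `DictCache` is illustrative only, not part of the library):

```python
# Sketch of a custom cache implementing the BaseCache interface used above:
# lookup/update keyed on (prompt, llm_string), plus clear.
from typing import Any, Dict, Optional, Sequence, Tuple

from langchain.cache import BaseCache
from langchain.schema import Generation

RETURN_VAL_TYPE = Sequence[Generation]


class DictCache(BaseCache):
    """Store cached generations in a plain dict (illustrative only)."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        # Return None on a miss, the stored generations on a hit.
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store = {}
```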
@@ -1,5 +1,6 @@
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
@@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
    "ChatOpenAI",
    "AzureChatOpenAI",
    "FakeListChatModel",
    "PromptLayerChatOpenAI",
    "ChatAnthropic",
    "ChatGooglePalm",
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
-from langchain.load.dump import dumpd
+from langchain.load.dump import dumpd, dumps
from langchain.schema import (
    AIMessage,
    BaseMessage,
@@ -35,6 +35,7 @@ def _get_verbosity() -> bool:


class BaseChatModel(BaseLanguageModel, ABC):
    cache: Optional[bool] = None
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callbacks: Callbacks = Field(default=None, exclude=True)
@@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return {}

    def _get_invocation_params(
        self,
        stop: Optional[List[str]] = None,
    ) -> dict:
        params = self.dict()
        params["stop"] = stop
        return params

    def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        if self.lc_serializable:
            params = {**kwargs, **{"stop": stop}}
            param_string = str(sorted([(k, v) for k, v in params.items()]))
            llm_string = dumps(self)
            return llm_string + "---" + param_string
        else:
            params = self._get_invocation_params(stop=stop)
            params = {**params, **kwargs}
            return str(sorted([(k, v) for k, v in params.items()]))

    def generate(
        self,
        messages: List[List[BaseMessage]],
@@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
        **kwargs: Any,
    ) -> LLMResult:
        """Top Level call"""

-       params = self.dict()
-       params["stop"] = stop
+       params = self._get_invocation_params(stop=stop)
        options = {"stop": stop}

        callback_manager = CallbackManager.configure(
@@ -87,14 +105,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
            dumpd(self), messages, invocation_params=params, options=options
        )

-       new_arg_supported = inspect.signature(self._generate).parameters.get(
-           "run_manager"
-       )
        try:
            results = [
-               self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
-               if new_arg_supported
-               else self._generate(m, stop=stop)
+               self._generate_with_cache(
+                   m, stop=stop, run_manager=run_manager, **kwargs
+               )
                for m in messages
            ]
        except (KeyboardInterrupt, Exception) as e:
@@ -118,8 +133,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
        **kwargs: Any,
    ) -> LLMResult:
        """Top Level call"""
-       params = self.dict()
-       params["stop"] = stop
+       params = self._get_invocation_params(stop=stop)
        options = {"stop": stop}

        callback_manager = AsyncCallbackManager.configure(
@@ -133,15 +147,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
            dumpd(self), messages, invocation_params=params, options=options
        )

-       new_arg_supported = inspect.signature(self._agenerate).parameters.get(
-           "run_manager"
-       )
        try:
            results = await asyncio.gather(
                *[
-                   self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
-                   if new_arg_supported
-                   else self._agenerate(m, stop=stop)
+                   self._agenerate_with_cache(
+                       m, stop=stop, run_manager=run_manager, **kwargs
+                   )
                    for m in messages
                ]
            )
@@ -178,6 +189,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
        )

    def _generate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if new_arg_supported:
                return self._generate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                return self._generate(messages, stop=stop, **kwargs)
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
                if new_arg_supported:
                    result = self._generate(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                    )
                else:
                    result = self._generate(messages, stop=stop, **kwargs)
                langchain.llm_cache.update(prompt, llm_string, result.generations)
                return result

    async def _agenerate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if new_arg_supported:
                return await self._agenerate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                return await self._agenerate(messages, stop=stop, **kwargs)
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
                if new_arg_supported:
                    result = await self._agenerate(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                    )
                else:
                    result = await self._agenerate(messages, stop=stop, **kwargs)
                langchain.llm_cache.update(prompt, llm_string, result.generations)
                return result

    @abstractmethod
    def _generate(
        self,
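The new `cache` field and `_generate_with_cache` above give each model three modes: follow the global cache, opt out with `cache=False`, or require a cache with `cache=True`. A small behavioral sketch, using the `FakeListChatModel` test helper added later in this diff (the assertions reflect my reading of the code above, not documented guarantees):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema import HumanMessage

langchain.llm_cache = InMemoryCache()
messages = [HumanMessage(content="Tell me a joke")]

# Default (cache=None): the global cache is used, so the second call repeats "a".
cached = FakeListChatModel(responses=["a", "b"])
assert cached(messages).content == "a"
assert cached(messages).content == "a"

# cache=False: the global cache is bypassed, so the responses advance.
uncached = FakeListChatModel(responses=["a", "b"], cache=False)
assert uncached(messages).content == "a"
assert uncached(messages).content == "b"

# cache=True with no global cache configured raises a ValueError.
langchain.llm_cache = None
strict = FakeListChatModel(responses=["a"], cache=True)
try:
    strict(messages)
except ValueError as err:
    print(err)  # Asked to cache, but no cache found at `langchain.cache`.
```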
langchain/chat_models/fake.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    responses: List
    i: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-list-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next response in the list, advancing the index."""
        response = self.responses[self.i]
        self.i += 1
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"responses": self.responses}
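A quick usage sketch of the fake model above (no global cache configured here, so calls pass straight through and the canned responses are returned in order):

```python
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema import HumanMessage

fake = FakeListChatModel(responses=["first", "second"])
print(fake([HumanMessage(content="hi")]).content)  # -> first
print(fake([HumanMessage(content="hi")]).content)  # -> second
```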
@@ -76,6 +76,11 @@ class Generation(Serializable):
    """May include things like reason for finishing (e.g. in OpenAI)"""
    # TODO: add log probs

    @property
    def lc_serializable(self) -> bool:
        """This class is LangChain serializable."""
        return True


class BaseMessage(Serializable):
    """Message object."""
@@ -88,6 +93,11 @@ class BaseMessage(Serializable):
    def type(self) -> str:
        """Type of the message, used for serialization."""

    @property
    def lc_serializable(self) -> bool:
        """This class is LangChain serializable."""
        return True


class HumanMessage(BaseMessage):
    """Type of message that is spoken by the human."""
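These `lc_serializable` flags are what let the cache layer round-trip generations through `dumps`/`loads`, as the `SQLAlchemyCache` changes earlier in this diff do. A minimal sketch of that round trip, under the assumption that the default `loads` allow-list accepts the `langchain.schema` namespace:

```python
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import AIMessage, ChatGeneration

gen = ChatGeneration(message=AIMessage(content="Hello!"))
serialized = dumps(gen)    # JSON string, suitable for a cache backend
restored = loads(serialized)
assert restored.message.content == "Hello!"
```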
@@ -1,13 +1,12 @@
"""Unit tests for ReAct."""

-from typing import Any, List, Mapping, Optional, Union
+from typing import Union

from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool
-from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.fake import FakeListLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction
@@ -22,33 +21,6 @@ Made in 2022."""
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")


-class FakeListLLM(LLM):
-    """Fake LLM for testing that outputs elements of a list."""
-
-    responses: List[str]
-    i: int = -1
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "fake_list"
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Increment counter, and then return response in that index."""
-        self.i += 1
-        return self.responses[self.i]
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {}


class FakeDocstore(Docstore):
    """Fake docstore for testing purposes."""
@@ -1,17 +1,17 @@
"""Test LLM callbacks."""
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.schema import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
    FakeCallbackHandler,
    FakeCallbackHandlerWithChatStart,
)
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM


def test_llm_with_callbacks() -> None:
    """Test LLM callbacks."""
    handler = FakeCallbackHandler()
-   llm = FakeLLM(callbacks=[handler], verbose=True)
+   llm = FakeListLLM(callbacks=[handler], verbose=True, responses=["foo"])
    output = llm("foo")
    assert output == "foo"
    assert handler.starts == 1
@@ -22,7 +22,9 @@ def test_llm_with_callbacks() -> None:
def test_chat_model_with_v1_callbacks() -> None:
    """Test chat model callbacks fall back to on_llm_start."""
    handler = FakeCallbackHandler()
-   llm = FakeChatModel(callbacks=[handler], verbose=True)
+   llm = FakeListChatModel(
+       callbacks=[handler], verbose=True, responses=["fake response"]
+   )
    output = llm([HumanMessage(content="foo")])
    assert output.content == "fake response"
    assert handler.starts == 1
@@ -35,7 +37,9 @@ def test_chat_model_with_v1_callbacks() -> None:
def test_chat_model_with_v2_callbacks() -> None:
    """Test chat model callbacks fall back to on_llm_start."""
    handler = FakeCallbackHandlerWithChatStart()
-   llm = FakeChatModel(callbacks=[handler], verbose=True)
+   llm = FakeListChatModel(
+       callbacks=[handler], verbose=True, responses=["fake response"]
+   )
    output = llm([HumanMessage(content="foo")])
    assert output.content == "fake response"
    assert handler.starts == 1
tests/unit_tests/test_cache.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""Test caching for LLMs and ChatModels."""
from typing import Dict, Generator, List, Union

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

import langchain
from langchain.cache import (
    InMemoryCache,
    SQLAlchemyCache,
)
from langchain.chat_models import FakeListChatModel
from langchain.chat_models.base import BaseChatModel, dumps
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    Generation,
    HumanMessage,
)


def get_sqlite_cache() -> SQLAlchemyCache:
    return SQLAlchemyCache(engine=create_engine("sqlite://"))


CACHE_OPTIONS = [
    InMemoryCache,
    get_sqlite_cache,
]


@pytest.fixture(autouse=True, params=CACHE_OPTIONS)
def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
    # Will be run before each test
    cache_instance = request.param
    langchain.llm_cache = cache_instance()
    if langchain.llm_cache:
        langchain.llm_cache.clear()
    else:
        raise ValueError("Cache not set. This should never happen.")

    yield

    # Will be run after each test
    if langchain.llm_cache:
        langchain.llm_cache.clear()
    else:
        raise ValueError("Cache not set. This should never happen.")


def test_llm_caching() -> None:
    prompt = "How are you?"
    response = "Test response"
    cached_response = "Cached test response"
    llm = FakeListLLM(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=prompt,
            llm_string=create_llm_string(llm),
            return_val=[Generation(text=cached_response)],
        )
        assert llm(prompt) == cached_response
    else:
        raise ValueError(
            "The cache is not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def test_old_sqlite_llm_caching() -> None:
    if isinstance(langchain.llm_cache, SQLAlchemyCache):
        prompt = "How are you?"
        response = "Test response"
        cached_response = "Cached test response"
        llm = FakeListLLM(responses=[response])
        items = [
            langchain.llm_cache.cache_schema(
                prompt=prompt,
                llm=create_llm_string(llm),
                response=cached_response,
                idx=0,
            )
        ]
        with Session(langchain.llm_cache.engine) as session, session.begin():
            for item in items:
                session.merge(item)
        assert llm(prompt) == cached_response


def test_chat_model_caching() -> None:
    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
    response = "Test response"
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(),
            return_val=[ChatGeneration(message=cached_message)],
        )
        result = llm(prompt)
        assert isinstance(result, AIMessage)
        assert result.content == cached_response
    else:
        raise ValueError(
            "The cache is not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def test_chat_model_caching_params() -> None:
    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
    response = "Test response"
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
    if langchain.llm_cache:
        langchain.llm_cache.update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(functions=[]),
            return_val=[ChatGeneration(message=cached_message)],
        )
        result = llm(prompt, functions=[])
        assert isinstance(result, AIMessage)
        assert result.content == cached_response
        result_no_params = llm(prompt)
        assert isinstance(result_no_params, AIMessage)
        assert result_no_params.content == response

    else:
        raise ValueError(
            "The cache is not set. This should never happen, as the pytest fixture "
            "`set_cache_and_teardown` always sets the cache."
        )


def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
    _dict: Dict = llm.dict()
    _dict["stop"] = None
    return str(sorted([(k, v) for k, v in _dict.items()]))