use get_llm_cache and set_llm_cache (#11741)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Harrison Chase authored 12 months ago; committed by GitHub
parent f3ad22e64a
commit 4a2f0c51a1

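This commit replaces direct reads and writes of the module-level `langchain.llm_cache` attribute with the `get_llm_cache()` / `set_llm_cache()` accessors from `langchain.globals`, across docstrings, the LLM and chat-model base classes, and the integration and unit tests. A minimal before/after sketch of the calling convention (InMemoryCache is used here purely for illustration):

    # Before/after sketch of the migration in this commit (illustrative only).

    # Old style: assign the cache object directly to the module attribute.
    import langchain
    from langchain.cache import InMemoryCache

    langchain.llm_cache = InMemoryCache()

    # New style: go through the accessors in langchain.globals.
    from langchain.globals import get_llm_cache, set_llm_cache

    set_llm_cache(InMemoryCache())
    assert get_llm_cache() is not None  # returns the cache set above, or None if unset
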
@@ -463,15 +463,15 @@ class RedisSemanticCache(BaseCache):
.. code-block:: python
import langchain
from langchain.globals import set_llm_cache
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings
langchain.llm_cache = RedisSemanticCache(
set_llm_cache(RedisSemanticCache(
redis_url="redis://localhost:6379",
embedding=OpenAIEmbeddings()
)
))
"""
self._cache_dict: Dict[str, RedisVectorstore] = {}
@@ -588,6 +588,7 @@ class GPTCache(BaseCache):
import gptcache
from gptcache.processor.pre import get_prompt
from gptcache.manager.factory import get_data_manager
from langchain.globals import set_llm_cache
# Avoid multiple caches using the same file,
# causing different llm model caches to affect each other
@@ -601,7 +602,7 @@ class GPTCache(BaseCache):
),
)
langchain.llm_cache = GPTCache(init_gptcache)
set_llm_cache(GPTCache(init_gptcache))
"""
try:

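The GPTCache docstring hunks above show only the imports and the tail of the example; the body of the `init_gptcache` callback falls between the two hunks and is not shown. A hedged reconstruction of the full example is below. The `data_map_path` keyword passed to `get_data_manager` is an assumption consistent with the import shown; the exact data-manager arguments depend on the installed gptcache version.

    import gptcache
    from gptcache.manager.factory import get_data_manager
    from gptcache.processor.pre import get_prompt

    from langchain.cache import GPTCache
    from langchain.globals import set_llm_cache


    def init_gptcache(cache_obj: gptcache.Cache, llm: str) -> None:
        # Keep a separate on-disk map per model string so different models
        # do not read each other's cached completions.
        # data_map_path is assumed here; adjust to your gptcache version.
        cache_obj.init(
            pre_embedding_func=get_prompt,
            data_manager=get_data_manager(data_map_path=f"map_cache_{llm}"),
        )


    set_llm_cache(GPTCache(init_gptcache))
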
@@ -15,7 +15,6 @@ from typing import (
cast,
)
import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
@@ -24,6 +23,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.globals import get_llm_cache
from langchain.load.dump import dumpd, dumps
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -487,7 +487,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
@@ -502,7 +503,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
@@ -512,7 +513,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
)
else:
result = self._generate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
@@ -526,7 +527,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
@@ -541,7 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
@@ -551,7 +553,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
llm_cache.update(prompt, llm_string, result.generations)
return result
@abstractmethod

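After these hunks, `_generate_with_cache` and `_agenerate_with_cache` resolve the cache once via `get_llm_cache()` and reuse the local `llm_cache` variable instead of touching `langchain.llm_cache`. A minimal usage sketch of the resulting behaviour (FakeListChatModel is used only as a stand-in for a real chat model):

    from langchain.cache import InMemoryCache
    from langchain.chat_models import FakeListChatModel
    from langchain.globals import set_llm_cache

    set_llm_cache(InMemoryCache())

    chat = FakeListChatModel(responses=["hello"])
    first = chat.predict("hi")   # cache miss: the model runs and the result is stored
    second = chat.predict("hi")  # cache hit: served via get_llm_cache(), no model call
    assert first == second == "hello"
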
@@ -121,7 +121,7 @@ def get_debug() -> bool:
return _debug or old_debug
def set_llm_cache(value: "BaseCache") -> None:
def set_llm_cache(value: Optional["BaseCache"]) -> None:
"""Set a new LLM cache, overwriting the previous value, if any."""
import langchain

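For orientation, here is a simplified sketch of what the two accessors in `langchain.globals` look like conceptually. This is an assumption, not the library source: the real module also keeps the legacy `langchain.llm_cache` attribute in sync for backwards compatibility, which is why it imports `langchain` as shown in the hunk above.

    from typing import Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        from langchain.schema import BaseCache  # import path assumed for this sketch

    # Module-level storage for the global LLM cache.
    _llm_cache: Optional["BaseCache"] = None


    def set_llm_cache(value: Optional["BaseCache"]) -> None:
        """Set a new LLM cache, overwriting the previous value, if any."""
        global _llm_cache
        _llm_cache = value


    def get_llm_cache() -> Optional["BaseCache"]:
        """Return the currently configured LLM cache, or None if none is set."""
        return _llm_cache
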
@@ -37,7 +37,6 @@ from tenacity import (
wait_exponential,
)
import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
@@ -46,6 +45,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.globals import get_llm_cache
from langchain.load.dump import dumpd
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -124,9 +124,10 @@ def get_prompts(
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
llm_cache = get_llm_cache()
for i, prompt in enumerate(prompts):
if langchain.llm_cache is not None:
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if llm_cache is not None:
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
@@ -143,11 +144,12 @@ def update_cache(
prompts: List[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output."""
llm_cache = get_llm_cache()
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
if langchain.llm_cache is not None:
langchain.llm_cache.update(prompt, llm_string, result)
if llm_cache is not None:
llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
@@ -624,7 +626,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
if get_llm_cache() is None or disregard_cache:
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
@@ -788,7 +790,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
if get_llm_cache() is None or disregard_cache:
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."

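The helpers `get_prompts` and `update_cache` now resolve the cache via `get_llm_cache()` themselves, so callers only configure the global cache. A minimal sketch mirroring the tests further down, using the public FakeListLLM and building the `llm_string` key the same way those tests do:

    from langchain.cache import InMemoryCache
    from langchain.globals import get_llm_cache, set_llm_cache
    from langchain.llms import FakeListLLM
    from langchain.schema import Generation

    set_llm_cache(InMemoryCache())
    llm = FakeListLLM(responses=["uncached response"])

    # Build the cache key string the same way the tests in this PR do.
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    # Pre-populate the cache, then generate: the prompt is served from the cache.
    get_llm_cache().update("foo", llm_string, [Generation(text="cached response")])
    result = llm.generate(["foo"])
    assert result.generations[0][0].text == "cached response"
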
@@ -18,8 +18,8 @@ git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydan
git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1))
git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1))
git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))

@@ -5,8 +5,8 @@ from typing import Any, Iterator, Tuple
import pytest
import langchain
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -39,12 +39,12 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session=session, keyspace=keyspace)
langchain.llm_cache = cache
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
@@ -59,12 +59,12 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
langchain.llm_cache = cache
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
@@ -85,12 +85,12 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
keyspace=keyspace,
embedding=FakeEmbeddings(),
)
langchain.llm_cache = sem_cache
set_llm_cache(sem_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],

@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
import pytest
import langchain
from langchain.cache import GPTCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -48,15 +48,15 @@ def test_gptcache_caching(
init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
) -> None:
"""Test gptcache default caching behavior."""
langchain.llm_cache = GPTCache(init_func)
set_llm_cache(GPTCache(init_func))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
_ = llm.generate(["foo", "bar", "foo"])
cache_output = langchain.llm_cache.lookup("foo", llm_string)
cache_output = get_llm_cache().lookup("foo", llm_string)
assert cache_output == [Generation(text="fizz")]
langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("bar", llm_string) is None
get_llm_cache().clear()
assert get_llm_cache().lookup("bar", llm_string) is None

@@ -12,8 +12,8 @@ from typing import Iterator
import pytest
import langchain
from langchain.cache import MomentoCache
from langchain.globals import set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -34,7 +34,7 @@ def momento_cache() -> Iterator[MomentoCache]:
)
try:
llm_cache = MomentoCache(client, cache_name)
langchain.llm_cache = llm_cache
set_llm_cache(llm_cache)
yield llm_cache
finally:
client.delete_cache(cache_name)

@@ -1,11 +1,11 @@
"""Test Redis cache functionality."""
import uuid
from typing import List
from typing import List, cast
import pytest
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.load.dump import dumps
from langchain.schema import Generation, LLMResult
from langchain.schema.embeddings import Embeddings
@@ -28,40 +28,42 @@ def random_string() -> str:
def test_redis_cache_ttl() -> None:
import redis
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = langchain.llm_cache._key("foo", "bar")
assert langchain.llm_cache.redis.pttl(key) > 0
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1))
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert llm_cache.redis.pttl(key) > 0
def test_redis_cache() -> None:
import redis
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
langchain.llm_cache.redis.flushall()
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.redis.flushall()
def test_redis_cache_chat() -> None:
import redis
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
langchain.llm_cache.update(
get_llm_cache().update(
dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)
output = llm.generate([prompt])
@@ -70,18 +72,21 @@ def test_redis_cache_chat() -> None:
llm_output={},
)
assert output == expected_output
langchain.llm_cache.redis.flushall()
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.redis.flushall()
def test_redis_semantic_cache() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
@@ -91,24 +96,26 @@ def test_redis_semantic_cache() -> None:
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
get_llm_cache().clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
langchain.llm_cache.clear(llm_string=llm_string)
get_llm_cache().clear(llm_string=llm_string)
def test_redis_semantic_cache_multi() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update(
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
output = llm.generate(
@@ -120,19 +127,21 @@ def test_redis_semantic_cache_multi() -> None:
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
get_llm_cache().clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
langchain.llm_cache.update(
get_llm_cache().update(
dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)
output = llm.generate([prompt])
@@ -141,7 +150,7 @@ def test_redis_semantic_cache_chat() -> None:
llm_output={},
)
assert output == expected_output
langchain.llm_cache.clear(llm_string=llm_string)
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@@ -170,9 +179,7 @@ def test_redis_semantic_cache_chat() -> None:
def test_redis_semantic_cache_hit(
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=embedding, redis_url=REDIS_TEST_URL
)
set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))
llm = FakeLLM()
params = llm.dict()
@@ -189,7 +196,7 @@ def test_redis_semantic_cache_hit(
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
print(prompt_i)
print(llm_generations_i)
langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
get_llm_cache().update(prompt_i, llm_string, llm_generations_i)
llm.generate(prompts)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}

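One detail in the Redis tests above: because `get_llm_cache()` is typed against the base cache interface, the tests `cast` the result back to `RedisCache` before touching Redis-specific members such as `.redis` and `._key`. A minimal sketch of that pattern (the Redis URL is a placeholder; point it at a disposable instance):

    from typing import cast

    import redis

    from langchain.cache import RedisCache
    from langchain.globals import get_llm_cache, set_llm_cache

    REDIS_URL = "redis://localhost:6379"  # placeholder value for this sketch

    set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_URL)))

    # Narrow the static type from the base cache class back to RedisCache.
    redis_cache = cast(RedisCache, get_llm_cache())
    redis_cache.redis.flushall()  # RedisCache exposes the underlying client as .redis
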
@@ -6,25 +6,25 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
import langchain
from langchain.cache import InMemoryCache, SQLAlchemyCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_caching() -> None:
"""Test caching behavior."""
langchain.llm_cache = InMemoryCache()
set_llm_cache(InMemoryCache())
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
cache_output = langchain.llm_cache.lookup("bar", llm_string)
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == expected_cache_output
langchain.llm_cache = None
set_llm_cache(None)
expected_generations = [
[Generation(text="fizz")],
[Generation(text="foo")],
@@ -52,17 +52,17 @@ def test_custom_caching() -> None:
response = Column(String)
engine = create_engine("sqlite://")
langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
cache_output = langchain.llm_cache.lookup("bar", llm_string)
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == expected_cache_output
langchain.llm_cache = None
set_llm_cache(None)
expected_generations = [
[Generation(text="fizz")],
[Generation(text="foo")],

@@ -6,13 +6,13 @@ from _pytest.fixtures import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
import langchain
from langchain.cache import (
InMemoryCache,
SQLAlchemyCache,
)
from langchain.chat_models import FakeListChatModel
from langchain.chat_models.base import BaseChatModel, dumps
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
@@ -36,18 +36,18 @@ CACHE_OPTIONS = [
def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
# Will be run before each test
cache_instance = request.param
langchain.llm_cache = cache_instance()
if langchain.llm_cache:
langchain.llm_cache.clear()
set_llm_cache(cache_instance())
if get_llm_cache():
get_llm_cache().clear()
else:
raise ValueError("Cache not set. This should never happen.")
yield
# Will be run after each test
if langchain.llm_cache:
langchain.llm_cache.clear()
langchain.llm_cache = None
if get_llm_cache():
get_llm_cache().clear()
set_llm_cache(None)
else:
raise ValueError("Cache not set. This should never happen.")
@@ -57,8 +57,8 @@ def test_llm_caching() -> None:
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
if get_llm_cache():
get_llm_cache().update(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
@@ -72,20 +72,21 @@ def test_old_sqlite_llm_caching() -> None:
def test_old_sqlite_llm_caching() -> None:
if isinstance(langchain.llm_cache, SQLAlchemyCache):
llm_cache = get_llm_cache()
if isinstance(llm_cache, SQLAlchemyCache):
prompt = "How are you?"
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
items = [
langchain.llm_cache.cache_schema(
llm_cache.cache_schema(
prompt=prompt,
llm=create_llm_string(llm),
response=cached_response,
idx=0,
)
]
with Session(langchain.llm_cache.engine) as session, session.begin():
with Session(llm_cache.engine) as session, session.begin():
for item in items:
session.merge(item)
assert llm(prompt) == cached_response
@@ -97,8 +98,8 @@ def test_chat_model_caching() -> None:
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
if get_llm_cache():
get_llm_cache().update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(),
return_val=[ChatGeneration(message=cached_message)],
@@ -119,8 +120,8 @@ def test_chat_model_caching_params() -> None:
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
if get_llm_cache():
get_llm_cache().update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(functions=[]),
return_val=[ChatGeneration(message=cached_message)],
@@ -144,13 +145,13 @@ def test_llm_cache_clear() -> None:
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
if langchain.llm_cache:
langchain.llm_cache.update(
if get_llm_cache():
get_llm_cache().update(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
langchain.llm_cache.clear()
get_llm_cache().clear()
assert llm(prompt) == response
else:
raise ValueError(
