You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/tests/integration_tests/cache/test_momento_cache.py

95 lines
2.9 KiB
Python

"""Test Momento cache functionality.
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
from __future__ import annotations
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from momento import CacheClient, Configurations, CredentialProvider
import langchain
from langchain.cache import MomentoCache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
def random_string() -> str:
    """Return a freshly generated UUID4 rendered as its canonical text form."""
    return f"{uuid.uuid4()}"
@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
    """Yield a MomentoCache backed by a uniquely named Momento cache.

    The cache is installed as the global ``langchain.llm_cache`` for the
    duration of the module. On teardown the previously installed global cache
    is restored, the Momento cache is deleted, and the client is closed.
    """
    cache_name = f"langchain-test-cache-{random_string()}"
    # Remember whatever was installed before so teardown can put it back.
    previous_llm_cache = getattr(langchain, "llm_cache", None)
    with CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    ) as client:
        try:
            llm_cache = MomentoCache(client, cache_name)
            langchain.llm_cache = llm_cache
            yield llm_cache
        finally:
            # Don't leave a cache whose backing store is about to be deleted
            # installed globally for later test modules.
            langchain.llm_cache = previous_llm_cache
            client.delete_cache(cache_name)
def test_invalid_ttl() -> None:
    """Constructing a MomentoCache with a negative TTL raises ValueError."""
    # Use the client as a context manager so its connections are closed.
    with CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    ) as client:
        with pytest.raises(ValueError):
            MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
    """A never-before-seen prompt yields FakeLLM's stub output ``foo``."""
    fake_llm = FakeLLM()
    expected = LLMResult(generations=[[Generation(text="foo")]])
    unseen_prompt = random_string()
    result = fake_llm.generate([unseen_prompt])
    assert result == expected
@pytest.mark.parametrize(
    "prompts, generations",
    [
        # One prompt, one generation
        ([random_string()], [[random_string()]]),
        # One prompt, two generations
        ([random_string()], [[random_string(), random_string()]]),
        # One prompt, three generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Two prompts, each with a different number of generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
)
def test_momento_cache_hit(
    momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
    """Entries seeded directly into the cache are what ``generate`` returns.

    Each prompt is stored in ``momento_cache`` under an ``llm_string`` built
    from the FakeLLM's parameters, then ``generate`` is expected to return
    exactly those seeded generations.
    """
    fake_llm = FakeLLM()
    llm_params = fake_llm.dict()
    llm_params["stop"] = None
    # NOTE(review): presumably mirrors how the LLM derives its cache key from
    # its params; the assertion below only holds if the two strings agree.
    llm_string = str(sorted(llm_params.items()))

    seeded_generations = [
        [Generation(text=text, generation_info=llm_params) for text in texts]
        for texts in generations
    ]
    for prompt, prompt_generations in zip(prompts, seeded_generations):
        momento_cache.update(prompt, llm_string, prompt_generations)

    assert fake_llm.generate(prompts) == LLMResult(
        generations=seeded_generations, llm_output={}
    )