langchain/tests/integration_tests/cache/test_momento_cache.py
Michael Landis 7047a2c1af
feat: add Momento as a standard cache and chat message history provider (#5221)
# Add Momento as a standard cache and chat message history provider

This PR adds Momento as a standard caching provider: it implements the
interface and adds integration tests and documentation. We also add
Momento as a chat message history provider, along with integration tests
and documentation.

[Momento](https://www.gomomento.com/) is a fully serverless cache.
Similar to S3 or DynamoDB, it requires zero configuration or
infrastructure management and is instantly available. Users sign up for
free and get 50GB of data in/out for free every month.

## Before submitting

 We have added documentation, notebooks, and integration tests
demonstrating usage.

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-05-25 19:13:21 -07:00

95 lines
2.9 KiB
Python

"""Test Momento cache functionality.
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
from __future__ import annotations
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from momento import CacheClient, Configurations, CredentialProvider
import langchain
from langchain.cache import MomentoCache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
def random_string() -> str:
    """Return the string form of a newly generated version-4 UUID."""
    unique_id = uuid.uuid4()
    return str(unique_id)
@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
    """Yield a ``MomentoCache`` installed as the global ``langchain.llm_cache``.

    A uniquely named cache is used for the module, backed by a client that
    reads its credentials from the ``MOMENTO_AUTH_TOKEN`` environment
    variable and uses a 30 second default TTL.

    Yields:
        The ``MomentoCache`` instance that was assigned to
        ``langchain.llm_cache``.

    On teardown the previously installed ``langchain.llm_cache`` is restored
    and the remote cache is deleted.
    """
    cache_name = f"langchain-test-cache-{random_string()}"
    client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    )
    # Remember the cache that was active before this module so the global can
    # be put back afterwards; otherwise later test modules would keep using a
    # cache that has already been deleted.
    previous_llm_cache = getattr(langchain, "llm_cache", None)
    try:
        llm_cache = MomentoCache(client, cache_name)
        langchain.llm_cache = llm_cache
        yield llm_cache
    finally:
        # Restore the global first so it is reset even if the delete fails.
        langchain.llm_cache = previous_llm_cache
        client.delete_cache(cache_name)
def test_invalid_ttl() -> None:
    """Expect ``ValueError`` when ``MomentoCache`` is given a negative ``ttl``."""
    configuration = Configurations.Laptop.v1()
    credentials = CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN")
    client = CacheClient(configuration, credentials, default_ttl=timedelta(seconds=30))
    negative_ttl = timedelta(seconds=-1)
    with pytest.raises(ValueError):
        MomentoCache(client, cache_name=random_string(), ttl=negative_ttl)
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
    """Generate for a fresh random prompt and expect a single "foo" generation.

    The ``momento_cache`` fixture is requested only for its side effect of
    installing the cache as ``langchain.llm_cache``.
    """
    fake_llm = FakeLLM()
    unseen_prompt = random_string()
    expected = LLMResult(generations=[[Generation(text="foo")]])
    actual = fake_llm.generate([unseen_prompt])
    assert actual == expected
@pytest.mark.parametrize(
    "prompts, generations",
    [
        # One prompt with one generation
        ([random_string()], [[random_string()]]),
        # One prompt with two generations
        ([random_string()], [[random_string(), random_string()]]),
        # One prompt with three generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Two prompts: the first with one generation, the second with two
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
)
def test_momento_cache_hit(
    momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
    """Seed the cache for each prompt and expect ``generate`` to return the entries.

    Args:
        momento_cache: Cache under test, provided by the fixture.
        prompts: Prompts to seed and then generate for.
        generations: For each prompt, the generation texts to store.
    """
    fake_llm = FakeLLM()
    llm_params = fake_llm.dict()
    llm_params["stop"] = None
    # NOTE(review): presumably mirrors how the LLM derives its cache key from
    # its parameters -- confirm against the LLM base class.
    llm_string = str(sorted(llm_params.items()))

    seeded_generations = []
    for texts in generations:
        seeded_generations.append(
            [Generation(text=text, generation_info=llm_params) for text in texts]
        )

    # Store each prompt's generations before asking the LLM to generate.
    for prompt, cached_generations in zip(prompts, seeded_generations):
        momento_cache.update(prompt, llm_string, cached_generations)

    actual = fake_llm.generate(prompts)
    expected = LLMResult(generations=seeded_generations, llm_output={})
    assert actual == expected