forked from Archives/langchain
feat: add Momento as a standard cache and chat message history provider (#5221)
# Add Momento as a standard cache and chat message history provider This PR adds Momento as a standard caching provider. Implements the interface, adds integration tests, and documentation. We also add Momento as a chat history message provider along with integration tests, and documentation. [Momento](https://www.gomomento.com/) is a fully serverless cache. Similar to S3 or DynamoDB, it requires zero configuration, infrastructure management, and is instantly available. Users sign up for free and get 50GB of data in/out for free every month. ## Before submitting ✅ We have added documentation, notebooks, and integration tests demonstrating usage. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
56ad56c812
commit
7047a2c1af
53
docs/integrations/momento.md
Normal file
53
docs/integrations/momento.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Momento
|
||||||
|
|
||||||
|
This page covers how to use the [Momento](https://gomomento.com) ecosystem within LangChain.
|
||||||
|
It is broken into two parts: installation and setup, and then references to specific Momento wrappers.
|
||||||
|
|
||||||
|
## Installation and Setup
|
||||||
|
|
||||||
|
- Sign up for a free account [here](https://docs.momentohq.com/getting-started) and get an auth token
|
||||||
|
- Install the Momento Python SDK with `pip install momento`
|
||||||
|
|
||||||
|
## Wrappers
|
||||||
|
|
||||||
|
### Cache
|
||||||
|
|
||||||
|
The Cache wrapper allows for [Momento](https://gomomento.com) to be used as a serverless, distributed, low-latency cache for LLM prompts and responses.
|
||||||
|
|
||||||
|
#### Standard Cache
|
||||||
|
|
||||||
|
The standard cache is the go-to use case for [Momento](https://gomomento.com) users in any environment.
|
||||||
|
|
||||||
|
Import the cache as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.cache import MomentoCache
|
||||||
|
```
|
||||||
|
|
||||||
|
And set up like so:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import timedelta
|
||||||
|
from momento import CacheClient, Configurations, CredentialProvider
|
||||||
|
import langchain
|
||||||
|
|
||||||
|
# Instantiate the Momento client
|
||||||
|
cache_client = CacheClient(
|
||||||
|
Configurations.Laptop.v1(),
|
||||||
|
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||||
|
default_ttl=timedelta(days=1))
|
||||||
|
|
||||||
|
# Choose a Momento cache name of your choice
|
||||||
|
cache_name = "langchain"
|
||||||
|
|
||||||
|
# Instantiate the LLM cache
|
||||||
|
langchain.llm_cache = MomentoCache(cache_client, cache_name)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory
|
||||||
|
|
||||||
|
Momento can be used as a distributed memory store for LLMs.
|
||||||
|
|
||||||
|
#### Chat Message History Memory
|
||||||
|
|
||||||
|
See [this notebook](../modules/memory/examples/momento_chat_message_history.ipynb) for a walkthrough of how to use Momento as a memory store for chat message history.
|
@ -0,0 +1,86 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "91c6a7ef",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Momento\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use [Momento Cache](https://gomomento.com) to store chat message history using the `MomentoChatMessageHistory` class. See the Momento [docs](https://docs.momentohq.com/getting-started) for more detail on how to get set up with Momento.\n",
|
||||||
|
"\n",
|
||||||
|
"Note that, by default we will create a cache if one with the given name doesn't already exist.\n",
|
||||||
|
"\n",
|
||||||
|
"You'll need to get a Momento auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "d15e3302",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from datetime import timedelta\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain.memory import MomentoChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"session_id = \"foo\"\n",
|
||||||
|
"cache_name = \"langchain\"\n",
|
||||||
|
"ttl = timedelta(days=1),\n",
|
||||||
|
"history = MomentoChatMessageHistory.from_client_params(\n",
|
||||||
|
" session_id, \n",
|
||||||
|
" cache_name,\n",
|
||||||
|
" ttl,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"history.add_user_message(\"hi!\")\n",
|
||||||
|
"\n",
|
||||||
|
"history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "64fc465e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
|
||||||
|
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"history.messages"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -41,7 +41,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 6,
|
||||||
"id": "f69f6283",
|
"id": "f69f6283",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -612,6 +612,115 @@
|
|||||||
"llm(\"Tell me joke\")"
|
"llm(\"Tell me joke\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "726fe754",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Momento Cache\n",
|
||||||
|
"Use [Momento](../../../../integrations/momento.md) to cache prompts and responses.\n",
|
||||||
|
"\n",
|
||||||
|
"Requires momento to use, uncomment below to install:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e8949f29",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# !pip install momento"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "56ea6a08",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You'll need to get a Momemto auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "2005f03a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from datetime import timedelta\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain.cache import MomentoCache\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"cache_name = \"langchain\"\n",
|
||||||
|
"ttl = timedelta(days=1)\n",
|
||||||
|
"langchain.llm_cache = MomentoCache.from_client_params(cache_name, ttl)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "c6a6c238",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 40.7 ms, sys: 16.5 ms, total: 57.2 ms\n",
|
||||||
|
"Wall time: 1.73 s\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%%time\n",
|
||||||
|
"# The first time, it is not yet in cache, so it should take longer\n",
|
||||||
|
"llm(\"Tell me a joke\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"id": "b8f78f9d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 3.16 ms, sys: 2.98 ms, total: 6.14 ms\n",
|
||||||
|
"Wall time: 57.9 ms\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%%time\n",
|
||||||
|
"# The second time it is, so it goes faster\n",
|
||||||
|
"# When run in the same region as the cache, latencies are single digit ms\n",
|
||||||
|
"llm(\"Tell me a joke\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "934943dc",
|
"id": "934943dc",
|
||||||
@ -909,9 +1018,9 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "venv"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
@ -923,7 +1032,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.8"
|
"version": "3.11.3"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,14 +1,30 @@
|
|||||||
"""Beta Feature: base interface for cache."""
|
"""Beta Feature: base interface for cache."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
|
from datetime import timedelta
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from langchain.utils import get_from_env
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import declarative_base
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -18,6 +34,9 @@ from langchain.embeddings.base import Embeddings
|
|||||||
from langchain.schema import Generation
|
from langchain.schema import Generation
|
||||||
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import momento
|
||||||
|
|
||||||
RETURN_VAL_TYPE = List[Generation]
|
RETURN_VAL_TYPE = List[Generation]
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +45,39 @@ def _hash(_input: str) -> str:
|
|||||||
return hashlib.md5(_input.encode()).hexdigest()
|
return hashlib.md5(_input.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
|
||||||
|
"""Dump generations to json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generations (RETURN_VAL_TYPE): A list of language model generations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Json representing a list of generations.
|
||||||
|
"""
|
||||||
|
return json.dumps([generation.dict() for generation in generations])
|
||||||
|
|
||||||
|
|
||||||
|
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
||||||
|
"""Load generations from json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generations_json (str): A string of json representing a list of generations.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Could not decode json string to list of generations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RETURN_VAL_TYPE: A list of generations.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
results = json.loads(generations_json)
|
||||||
|
return [Generation(**generation_dict) for generation_dict in results]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not decode json to list of generations: {generations_json}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseCache(ABC):
|
class BaseCache(ABC):
|
||||||
"""Base interface for cache."""
|
"""Base interface for cache."""
|
||||||
|
|
||||||
@ -390,3 +442,179 @@ class GPTCache(BaseCache):
|
|||||||
gptcache_instance.flush()
|
gptcache_instance.flush()
|
||||||
|
|
||||||
self.gptcache_dict.clear()
|
self.gptcache_dict.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
|
||||||
|
"""Create cache if it doesn't exist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
Exception: Unexpected response
|
||||||
|
"""
|
||||||
|
from momento.responses import CreateCache
|
||||||
|
|
||||||
|
create_cache_response = cache_client.create_cache(cache_name)
|
||||||
|
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
|
||||||
|
create_cache_response, CreateCache.CacheAlreadyExists
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
elif isinstance(create_cache_response, CreateCache.Error):
|
||||||
|
raise create_cache_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_ttl(ttl: Optional[timedelta]) -> None:
|
||||||
|
if ttl is not None and ttl <= timedelta(seconds=0):
|
||||||
|
raise ValueError(f"ttl must be positive but was {ttl}.")
|
||||||
|
|
||||||
|
|
||||||
|
class MomentoCache(BaseCache):
|
||||||
|
"""Cache that uses Momento as a backend. See https://gomomento.com/"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cache_client: momento.CacheClient,
|
||||||
|
cache_name: str,
|
||||||
|
*,
|
||||||
|
ttl: Optional[timedelta] = None,
|
||||||
|
ensure_cache_exists: bool = True,
|
||||||
|
):
|
||||||
|
"""Instantiate a prompt cache using Momento as a backend.
|
||||||
|
|
||||||
|
Note: to instantiate the cache client passed to MomentoCache,
|
||||||
|
you must have a Momento account. See https://gomomento.com/.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_client (CacheClient): The Momento cache client.
|
||||||
|
cache_name (str): The name of the cache to use to store the data.
|
||||||
|
ttl (Optional[timedelta], optional): The time to live for the cache items.
|
||||||
|
Defaults to None, ie use the client default TTL.
|
||||||
|
ensure_cache_exists (bool, optional): Create the cache if it doesn't
|
||||||
|
exist. Defaults to True.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: Momento python package is not installed.
|
||||||
|
TypeError: cache_client is not of type momento.CacheClientObject
|
||||||
|
ValueError: ttl is non-null and non-negative
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from momento import CacheClient
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import momento python package. "
|
||||||
|
"Please install it with `pip install momento`."
|
||||||
|
)
|
||||||
|
if not isinstance(cache_client, CacheClient):
|
||||||
|
raise TypeError("cache_client must be a momento.CacheClient object.")
|
||||||
|
_validate_ttl(ttl)
|
||||||
|
if ensure_cache_exists:
|
||||||
|
_ensure_cache_exists(cache_client, cache_name)
|
||||||
|
|
||||||
|
self.cache_client = cache_client
|
||||||
|
self.cache_name = cache_name
|
||||||
|
self.ttl = ttl
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_client_params(
|
||||||
|
cls,
|
||||||
|
cache_name: str,
|
||||||
|
ttl: timedelta,
|
||||||
|
*,
|
||||||
|
configuration: Optional[momento.config.Configuration] = None,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> MomentoCache:
|
||||||
|
"""Construct cache from CacheClient parameters."""
|
||||||
|
try:
|
||||||
|
from momento import CacheClient, Configurations, CredentialProvider
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import momento python package. "
|
||||||
|
"Please install it with `pip install momento`."
|
||||||
|
)
|
||||||
|
if configuration is None:
|
||||||
|
configuration = Configurations.Laptop.v1()
|
||||||
|
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
|
||||||
|
credentials = CredentialProvider.from_string(auth_token)
|
||||||
|
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
|
||||||
|
return cls(cache_client, cache_name, ttl=ttl, **kwargs)
|
||||||
|
|
||||||
|
def __key(self, prompt: str, llm_string: str) -> str:
|
||||||
|
"""Compute cache key from prompt and associated model and settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt run through the language model.
|
||||||
|
llm_string (str): The language model version and settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The cache key.
|
||||||
|
"""
|
||||||
|
return _hash(prompt + llm_string)
|
||||||
|
|
||||||
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
|
"""Lookup llm generations in cache by prompt and associated model and settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt run through the language model.
|
||||||
|
llm_string (str): The language model version and settings.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[RETURN_VAL_TYPE]: A list of language model generations.
|
||||||
|
"""
|
||||||
|
from momento.responses import CacheGet
|
||||||
|
|
||||||
|
generations = []
|
||||||
|
|
||||||
|
get_response = self.cache_client.get(
|
||||||
|
self.cache_name, self.__key(prompt, llm_string)
|
||||||
|
)
|
||||||
|
if isinstance(get_response, CacheGet.Hit):
|
||||||
|
value = get_response.value_string
|
||||||
|
generations = _load_generations_from_json(value)
|
||||||
|
elif isinstance(get_response, CacheGet.Miss):
|
||||||
|
pass
|
||||||
|
elif isinstance(get_response, CacheGet.Error):
|
||||||
|
raise get_response.inner_exception
|
||||||
|
return generations if generations else None
|
||||||
|
|
||||||
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
|
"""Store llm generations in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt run through the language model.
|
||||||
|
llm_string (str): The language model string.
|
||||||
|
return_val (RETURN_VAL_TYPE): A list of language model generations.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
Exception: Unexpected response
|
||||||
|
"""
|
||||||
|
key = self.__key(prompt, llm_string)
|
||||||
|
value = _dump_generations_to_json(return_val)
|
||||||
|
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
|
||||||
|
from momento.responses import CacheSet
|
||||||
|
|
||||||
|
if isinstance(set_response, CacheSet.Success):
|
||||||
|
pass
|
||||||
|
elif isinstance(set_response, CacheSet.Error):
|
||||||
|
raise set_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response: {set_response}")
|
||||||
|
|
||||||
|
def clear(self, **kwargs: Any) -> None:
|
||||||
|
"""Clear the cache.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
"""
|
||||||
|
from momento.responses import CacheFlush
|
||||||
|
|
||||||
|
flush_response = self.cache_client.flush_cache(self.cache_name)
|
||||||
|
if isinstance(flush_response, CacheFlush.Success):
|
||||||
|
pass
|
||||||
|
elif isinstance(flush_response, CacheFlush.Error):
|
||||||
|
raise flush_response.inner_exception
|
||||||
|
@ -3,6 +3,7 @@ from langchain.memory.buffer import (
|
|||||||
ConversationStringBufferMemory,
|
ConversationStringBufferMemory,
|
||||||
)
|
)
|
||||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||||
|
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.cassandra import (
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
CassandraChatMessageHistory,
|
CassandraChatMessageHistory,
|
||||||
)
|
)
|
||||||
@ -50,4 +51,5 @@ __all__ = [
|
|||||||
"FileChatMessageHistory",
|
"FileChatMessageHistory",
|
||||||
"MongoDBChatMessageHistory",
|
"MongoDBChatMessageHistory",
|
||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
|
"MomentoChatMessageHistory",
|
||||||
]
|
]
|
||||||
|
@ -7,6 +7,7 @@ from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
|||||||
from langchain.memory.chat_message_histories.firestore import (
|
from langchain.memory.chat_message_histories.firestore import (
|
||||||
FirestoreChatMessageHistory,
|
FirestoreChatMessageHistory,
|
||||||
)
|
)
|
||||||
|
from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
|
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
||||||
@ -24,4 +25,5 @@ __all__ = [
|
|||||||
"MongoDBChatMessageHistory",
|
"MongoDBChatMessageHistory",
|
||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
"ZepChatMessageHistory",
|
"ZepChatMessageHistory",
|
||||||
|
"MomentoChatMessageHistory",
|
||||||
]
|
]
|
||||||
|
200
langchain/memory/chat_message_histories/momento.py
Normal file
200
langchain/memory/chat_message_histories/momento.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseChatMessageHistory,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
_message_to_dict,
|
||||||
|
messages_from_dict,
|
||||||
|
)
|
||||||
|
from langchain.utils import get_from_env
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import momento
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
|
||||||
|
"""Create cache if it doesn't exist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
Exception: Unexpected response
|
||||||
|
"""
|
||||||
|
from momento.responses import CreateCache
|
||||||
|
|
||||||
|
create_cache_response = cache_client.create_cache(cache_name)
|
||||||
|
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
|
||||||
|
create_cache_response, CreateCache.CacheAlreadyExists
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
elif isinstance(create_cache_response, CreateCache.Error):
|
||||||
|
raise create_cache_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
|
||||||
|
|
||||||
|
|
||||||
|
class MomentoChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Chat message history cache that uses Momento as a backend.
|
||||||
|
See https://gomomento.com/"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
cache_client: momento.CacheClient,
|
||||||
|
cache_name: str,
|
||||||
|
*,
|
||||||
|
key_prefix: str = "message_store:",
|
||||||
|
ttl: Optional[timedelta] = None,
|
||||||
|
ensure_cache_exists: bool = True,
|
||||||
|
):
|
||||||
|
"""Instantiate a chat message history cache that uses Momento as a backend.
|
||||||
|
|
||||||
|
Note: to instantiate the cache client passed to MomentoChatMessageHistory,
|
||||||
|
you must have a Momento account at https://gomomento.com/.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id (str): The session ID to use for this chat session.
|
||||||
|
cache_client (CacheClient): The Momento cache client.
|
||||||
|
cache_name (str): The name of the cache to use to store the messages.
|
||||||
|
key_prefix (str, optional): The prefix to apply to the cache key.
|
||||||
|
Defaults to "message_store:".
|
||||||
|
ttl (Optional[timedelta], optional): The TTL to use for the messages.
|
||||||
|
Defaults to None, ie the default TTL of the cache will be used.
|
||||||
|
ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: Momento python package is not installed.
|
||||||
|
TypeError: cache_client is not of type momento.CacheClientObject
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from momento import CacheClient
|
||||||
|
from momento.requests import CollectionTtl
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import momento python package. "
|
||||||
|
"Please install it with `pip install momento`."
|
||||||
|
)
|
||||||
|
if not isinstance(cache_client, CacheClient):
|
||||||
|
raise TypeError("cache_client must be a momento.CacheClient object.")
|
||||||
|
if ensure_cache_exists:
|
||||||
|
_ensure_cache_exists(cache_client, cache_name)
|
||||||
|
self.key = key_prefix + session_id
|
||||||
|
self.cache_client = cache_client
|
||||||
|
self.cache_name = cache_name
|
||||||
|
if ttl is not None:
|
||||||
|
self.ttl = CollectionTtl.of(ttl)
|
||||||
|
else:
|
||||||
|
self.ttl = CollectionTtl.from_cache_ttl()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_client_params(
|
||||||
|
cls,
|
||||||
|
session_id: str,
|
||||||
|
cache_name: str,
|
||||||
|
ttl: timedelta,
|
||||||
|
*,
|
||||||
|
configuration: Optional[momento.config.Configuration] = None,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> MomentoChatMessageHistory:
|
||||||
|
"""Construct cache from CacheClient parameters."""
|
||||||
|
try:
|
||||||
|
from momento import CacheClient, Configurations, CredentialProvider
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import momento python package. "
|
||||||
|
"Please install it with `pip install momento`."
|
||||||
|
)
|
||||||
|
if configuration is None:
|
||||||
|
configuration = Configurations.Laptop.v1()
|
||||||
|
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
|
||||||
|
credentials = CredentialProvider.from_string(auth_token)
|
||||||
|
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
|
||||||
|
return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> list[BaseMessage]: # type: ignore[override]
|
||||||
|
"""Retrieve the messages from Momento.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error
|
||||||
|
Exception: Unexpected response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[BaseMessage]: List of cached messages
|
||||||
|
"""
|
||||||
|
from momento.responses import CacheListFetch
|
||||||
|
|
||||||
|
fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
|
||||||
|
|
||||||
|
if isinstance(fetch_response, CacheListFetch.Hit):
|
||||||
|
items = [json.loads(m) for m in fetch_response.value_list_string]
|
||||||
|
return messages_from_dict(items)
|
||||||
|
elif isinstance(fetch_response, CacheListFetch.Miss):
|
||||||
|
return []
|
||||||
|
elif isinstance(fetch_response, CacheListFetch.Error):
|
||||||
|
raise fetch_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response: {fetch_response}")
|
||||||
|
|
||||||
|
def add_user_message(self, message: str) -> None:
|
||||||
|
"""Store a user message in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): The message to store.
|
||||||
|
"""
|
||||||
|
self.__add_message(HumanMessage(content=message))
|
||||||
|
|
||||||
|
def add_ai_message(self, message: str) -> None:
|
||||||
|
"""Store an AI message in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): The message to store.
|
||||||
|
"""
|
||||||
|
self.__add_message(AIMessage(content=message))
|
||||||
|
|
||||||
|
def __add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Store a message in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (BaseMessage): The message object to store.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error.
|
||||||
|
Exception: Unexpected response.
|
||||||
|
"""
|
||||||
|
from momento.responses import CacheListPushBack
|
||||||
|
|
||||||
|
item = json.dumps(_message_to_dict(message))
|
||||||
|
push_response = self.cache_client.list_push_back(
|
||||||
|
self.cache_name, self.key, item, ttl=self.ttl
|
||||||
|
)
|
||||||
|
if isinstance(push_response, CacheListPushBack.Success):
|
||||||
|
return None
|
||||||
|
elif isinstance(push_response, CacheListPushBack.Error):
|
||||||
|
raise push_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response: {push_response}")
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove the session's messages from the cache.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SdkException: Momento service or network error.
|
||||||
|
Exception: Unexpected response.
|
||||||
|
"""
|
||||||
|
from momento.responses import CacheDelete
|
||||||
|
|
||||||
|
delete_response = self.cache_client.delete(self.cache_name, self.key)
|
||||||
|
if isinstance(delete_response, CacheDelete.Success):
|
||||||
|
return None
|
||||||
|
elif isinstance(delete_response, CacheDelete.Error):
|
||||||
|
raise delete_response.inner_exception
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response: {delete_response}")
|
41
poetry.lock
generated
41
poetry.lock
generated
@ -2621,7 +2621,7 @@ name = "grpcio"
|
|||||||
version = "1.47.5"
|
version = "1.47.5"
|
||||||
description = "HTTP/2-based RPC framework"
|
description = "HTTP/2-based RPC framework"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
files = [
|
files = [
|
||||||
{file = "grpcio-1.47.5-cp310-cp310-linux_armv7l.whl", hash = "sha256:acc73289d0c44650aa1f21eccfa967f5623b01c3b5e2b4596fe5f9c5bf10956d"},
|
{file = "grpcio-1.47.5-cp310-cp310-linux_armv7l.whl", hash = "sha256:acc73289d0c44650aa1f21eccfa967f5623b01c3b5e2b4596fe5f9c5bf10956d"},
|
||||||
@ -4462,6 +4462,39 @@ files = [
|
|||||||
{file = "mmh3-3.1.0.tar.gz", hash = "sha256:9b0f2b2ab4a915333c9d1089572e290a021ebb5b900bb7f7114dccc03995d732"},
|
{file = "mmh3-3.1.0.tar.gz", hash = "sha256:9b0f2b2ab4a915333c9d1089572e290a021ebb5b900bb7f7114dccc03995d732"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "momento"
|
||||||
|
version = "1.5.0"
|
||||||
|
description = "SDK for Momento"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7,<4.0"
|
||||||
|
files = [
|
||||||
|
{file = "momento-1.5.0-py3-none-any.whl", hash = "sha256:7f633fb26ddf1bfcaf99a7add37b085f0e23c96f85972b655eaaf2de1c61d7f1"},
|
||||||
|
{file = "momento-1.5.0.tar.gz", hash = "sha256:68ca5d24b4cb08c5c0bd22d4edd3b8b0fcf087a85d30673cb2c55b11971c76ec"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
grpcio = ">=1.46.0,<2.0.0"
|
||||||
|
momento-wire-types = ">=0.64,<0.65"
|
||||||
|
pyjwt = ">=2.4.0,<3.0.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "momento-wire-types"
|
||||||
|
version = "0.64.0"
|
||||||
|
description = "Momento Client Proto Generated Files"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7,<4.0"
|
||||||
|
files = [
|
||||||
|
{file = "momento_wire_types-0.64.0-py3-none-any.whl", hash = "sha256:30a5d523cef9209c0863db25e6344044b0c7240fea183a41c1433ca83cefea5b"},
|
||||||
|
{file = "momento_wire_types-0.64.0.tar.gz", hash = "sha256:5d3647210d49d0c3032a74ae5f5cc012a9faf826786272e1436a3d84d70a8bd5"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
grpcio = "*"
|
||||||
|
protobuf = ">=3,<5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "monotonic"
|
name = "monotonic"
|
||||||
version = "1.6"
|
version = "1.6"
|
||||||
@ -6410,7 +6443,7 @@ name = "protobuf"
|
|||||||
version = "3.19.6"
|
version = "3.19.6"
|
||||||
description = "Protocol Buffers"
|
description = "Protocol Buffers"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.5"
|
python-versions = ">=3.5"
|
||||||
files = [
|
files = [
|
||||||
{file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
|
{file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
|
||||||
@ -10865,7 +10898,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
|||||||
cffi = ["cffi (>=1.11)"]
|
cffi = ["cffi (>=1.11)"]
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||||
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"]
|
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"]
|
||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
docarray = ["docarray"]
|
docarray = ["docarray"]
|
||||||
@ -10879,4 +10912,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "0f2ed0b37063f533b3403d545b48633ced4d02c15bc8f4e47f1ded1652ab9764"
|
content-hash = "5e83a1f4ca8c0d3107363e393485174fd72ce9db93db5dc7c21b2dd37b184e66"
|
||||||
|
@ -97,6 +97,7 @@ scikit-learn = {version = "^1.2.2", optional = true}
|
|||||||
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
|
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
|
||||||
azure-ai-vision = {version = "^0.11.1b1", optional = true}
|
azure-ai-vision = {version = "^0.11.1b1", optional = true}
|
||||||
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
|
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
|
||||||
|
momento = {version = "^1.5.0", optional = true}
|
||||||
bibtexparser = {version = "^1.4.0", optional = true}
|
bibtexparser = {version = "^1.4.0", optional = true}
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
@ -160,6 +161,7 @@ pymongo = "^4.3.3"
|
|||||||
cassandra-driver = "^3.27.0"
|
cassandra-driver = "^3.27.0"
|
||||||
arxiv = "^1.4"
|
arxiv = "^1.4"
|
||||||
mastodon-py = "^1.8.1"
|
mastodon-py = "^1.8.1"
|
||||||
|
momento = "^1.5.0"
|
||||||
|
|
||||||
[tool.poetry.group.lint.dependencies]
|
[tool.poetry.group.lint.dependencies]
|
||||||
ruff = "^0.0.249"
|
ruff = "^0.0.249"
|
||||||
@ -253,6 +255,7 @@ all = [
|
|||||||
"azure-ai-formrecognizer",
|
"azure-ai-formrecognizer",
|
||||||
"azure-ai-vision",
|
"azure-ai-vision",
|
||||||
"azure-cognitiveservices-speech",
|
"azure-cognitiveservices-speech",
|
||||||
|
"momento"
|
||||||
]
|
]
|
||||||
|
|
||||||
# An extra used to be able to add extended testing.
|
# An extra used to be able to add extended testing.
|
||||||
|
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""Test Momento cache functionality.
|
||||||
|
|
||||||
|
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||||
|
Momento auth token. This can be obtained by signing up for a free
|
||||||
|
Momento account at https://gomomento.com/.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from momento import CacheClient, Configurations, CredentialProvider
|
||||||
|
|
||||||
|
import langchain
|
||||||
|
from langchain.cache import MomentoCache
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
def random_string() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def momento_cache() -> Iterator[MomentoCache]:
|
||||||
|
cache_name = f"langchain-test-cache-{random_string()}"
|
||||||
|
client = CacheClient(
|
||||||
|
Configurations.Laptop.v1(),
|
||||||
|
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||||
|
default_ttl=timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
llm_cache = MomentoCache(client, cache_name)
|
||||||
|
langchain.llm_cache = llm_cache
|
||||||
|
yield llm_cache
|
||||||
|
finally:
|
||||||
|
client.delete_cache(cache_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_ttl() -> None:
|
||||||
|
client = CacheClient(
|
||||||
|
Configurations.Laptop.v1(),
|
||||||
|
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||||
|
default_ttl=timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
|
||||||
|
llm = FakeLLM()
|
||||||
|
stub_llm_output = LLMResult(generations=[[Generation(text="foo")]])
|
||||||
|
assert llm.generate([random_string()]) == stub_llm_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prompts, generations",
|
||||||
|
[
|
||||||
|
# Single prompt, single generation
|
||||||
|
([random_string()], [[random_string()]]),
|
||||||
|
# Single prompt, multiple generations
|
||||||
|
([random_string()], [[random_string(), random_string()]]),
|
||||||
|
# Single prompt, multiple generations
|
||||||
|
([random_string()], [[random_string(), random_string(), random_string()]]),
|
||||||
|
# Multiple prompts, multiple generations
|
||||||
|
(
|
||||||
|
[random_string(), random_string()],
|
||||||
|
[[random_string()], [random_string(), random_string()]],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_momento_cache_hit(
|
||||||
|
momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
|
||||||
|
) -> None:
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm.dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
|
||||||
|
llm_generations = [
|
||||||
|
[
|
||||||
|
Generation(text=generation, generation_info=params)
|
||||||
|
for generation in prompt_i_generations
|
||||||
|
]
|
||||||
|
for prompt_i_generations in generations
|
||||||
|
]
|
||||||
|
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
|
||||||
|
momento_cache.update(prompt_i, llm_string, llm_generations_i)
|
||||||
|
|
||||||
|
assert llm.generate(prompts) == LLMResult(
|
||||||
|
generations=llm_generations, llm_output={}
|
||||||
|
)
|
70
tests/integration_tests/memory/test_momento.py
Normal file
70
tests/integration_tests/memory/test_momento.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
"""Test Momento chat message history functionality.
|
||||||
|
|
||||||
|
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||||
|
Momento auth token. This can be obtained by signing up for a free
|
||||||
|
Momento account at https://gomomento.com/.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from momento import CacheClient, Configurations, CredentialProvider
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
|
||||||
|
from langchain.schema import _message_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
def random_string() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def message_history() -> Iterator[MomentoChatMessageHistory]:
|
||||||
|
cache_name = f"langchain-test-cache-{random_string()}"
|
||||||
|
client = CacheClient(
|
||||||
|
Configurations.Laptop.v1(),
|
||||||
|
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||||
|
default_ttl=timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
chat_message_history = MomentoChatMessageHistory(
|
||||||
|
session_id="my-test-session",
|
||||||
|
cache_client=client,
|
||||||
|
cache_name=cache_name,
|
||||||
|
)
|
||||||
|
yield chat_message_history
|
||||||
|
finally:
|
||||||
|
client.delete_cache(cache_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_empty_on_new_session(
|
||||||
|
message_history: MomentoChatMessageHistory,
|
||||||
|
) -> None:
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="foo", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
assert memory.chat_memory.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add some messages to the memory store
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
# Verify that the messages are in the store
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# Verify clearing the store
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue
Block a user