forked from Archives/langchain
feat: add Momento as a standard cache and chat message history provider (#5221)
# Add Momento as a standard cache and chat message history provider This PR adds Momento as a standard caching provider. Implements the interface, adds integration tests, and documentation. We also add Momento as a chat history message provider along with integration tests, and documentation. [Momento](https://www.gomomento.com/) is a fully serverless cache. Similar to S3 or DynamoDB, it requires zero configuration, infrastructure management, and is instantly available. Users sign up for free and get 50GB of data in/out for free every month. ## Before submitting ✅ We have added documentation, notebooks, and integration tests demonstrating usage. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
56ad56c812
commit
7047a2c1af
53
docs/integrations/momento.md
Normal file
53
docs/integrations/momento.md
Normal file
@ -0,0 +1,53 @@
|
||||
# Momento
|
||||
|
||||
This page covers how to use the [Momento](https://gomomento.com) ecosystem within LangChain.
|
||||
It is broken into two parts: installation and setup, and then references to specific Momento wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Sign up for a free account [here](https://docs.momentohq.com/getting-started) and get an auth token
|
||||
- Install the Momento Python SDK with `pip install momento`
|
||||
|
||||
## Wrappers
|
||||
|
||||
### Cache
|
||||
|
||||
The Cache wrapper allows for [Momento](https://gomomento.com) to be used as a serverless, distributed, low-latency cache for LLM prompts and responses.
|
||||
|
||||
#### Standard Cache
|
||||
|
||||
The standard cache is the go-to use case for [Momento](https://gomomento.com) users in any environment.
|
||||
|
||||
Import the cache as follows:
|
||||
|
||||
```python
|
||||
from langchain.cache import MomentoCache
|
||||
```
|
||||
|
||||
And set up like so:
|
||||
|
||||
```python
|
||||
from datetime import timedelta
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
import langchain
|
||||
|
||||
# Instantiate the Momento client
|
||||
cache_client = CacheClient(
|
||||
Configurations.Laptop.v1(),
|
||||
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||
default_ttl=timedelta(days=1))
|
||||
|
||||
# Choose a Momento cache name of your choice
|
||||
cache_name = "langchain"
|
||||
|
||||
# Instantiate the LLM cache
|
||||
langchain.llm_cache = MomentoCache(cache_client, cache_name)
|
||||
```
|
||||
|
||||
### Memory
|
||||
|
||||
Momento can be used as a distributed memory store for LLMs.
|
||||
|
||||
#### Chat Message History Memory
|
||||
|
||||
See [this notebook](../modules/memory/examples/momento_chat_message_history.ipynb) for a walkthrough of how to use Momento as a memory store for chat message history.
|
@ -0,0 +1,86 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91c6a7ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Momento\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use [Momento Cache](https://gomomento.com) to store chat message history using the `MomentoChatMessageHistory` class. See the Momento [docs](https://docs.momentohq.com/getting-started) for more detail on how to get set up with Momento.\n",
|
||||
"\n",
|
||||
"Note that, by default we will create a cache if one with the given name doesn't already exist.\n",
|
||||
"\n",
|
||||
"You'll need to get a Momento auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "d15e3302",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datetime import timedelta\n",
|
||||
"\n",
|
||||
"from langchain.memory import MomentoChatMessageHistory\n",
|
||||
"\n",
|
||||
"session_id = \"foo\"\n",
|
||||
"cache_name = \"langchain\"\n",
|
||||
"ttl = timedelta(days=1),\n",
|
||||
"history = MomentoChatMessageHistory.from_client_params(\n",
|
||||
" session_id, \n",
|
||||
" cache_name,\n",
|
||||
" ttl,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"history.add_user_message(\"hi!\")\n",
|
||||
"\n",
|
||||
"history.add_ai_message(\"whats up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "64fc465e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"history.messages"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -41,7 +41,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"id": "f69f6283",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -612,6 +612,115 @@
|
||||
"llm(\"Tell me joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "726fe754",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Momento Cache\n",
|
||||
"Use [Momento](../../../../integrations/momento.md) to cache prompts and responses.\n",
|
||||
"\n",
|
||||
"Requires momento to use, uncomment below to install:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8949f29",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install momento"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "56ea6a08",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You'll need to get a Momemto auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "2005f03a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datetime import timedelta\n",
|
||||
"\n",
|
||||
"from langchain.cache import MomentoCache\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"cache_name = \"langchain\"\n",
|
||||
"ttl = timedelta(days=1)\n",
|
||||
"langchain.llm_cache = MomentoCache.from_client_params(cache_name, ttl)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "c6a6c238",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 40.7 ms, sys: 16.5 ms, total: 57.2 ms\n",
|
||||
"Wall time: 1.73 s\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The first time, it is not yet in cache, so it should take longer\n",
|
||||
"llm(\"Tell me a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "b8f78f9d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 3.16 ms, sys: 2.98 ms, total: 6.14 ms\n",
|
||||
"Wall time: 57.9 ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The second time it is, so it goes faster\n",
|
||||
"# When run in the same region as the cache, latencies are single digit ms\n",
|
||||
"llm(\"Tell me a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "934943dc",
|
||||
@ -909,9 +1018,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "venv"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@ -923,7 +1032,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1,14 +1,30 @@
|
||||
"""Beta Feature: base interface for cache."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
from datetime import timedelta
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
@ -18,6 +34,9 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Generation
|
||||
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
|
||||
RETURN_VAL_TYPE = List[Generation]
|
||||
|
||||
|
||||
@ -26,6 +45,39 @@ def _hash(_input: str) -> str:
|
||||
return hashlib.md5(_input.encode()).hexdigest()
|
||||
|
||||
|
||||
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
|
||||
"""Dump generations to json.
|
||||
|
||||
Args:
|
||||
generations (RETURN_VAL_TYPE): A list of language model generations.
|
||||
|
||||
Returns:
|
||||
str: Json representing a list of generations.
|
||||
"""
|
||||
return json.dumps([generation.dict() for generation in generations])
|
||||
|
||||
|
||||
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
||||
"""Load generations from json.
|
||||
|
||||
Args:
|
||||
generations_json (str): A string of json representing a list of generations.
|
||||
|
||||
Raises:
|
||||
ValueError: Could not decode json string to list of generations.
|
||||
|
||||
Returns:
|
||||
RETURN_VAL_TYPE: A list of generations.
|
||||
"""
|
||||
try:
|
||||
results = json.loads(generations_json)
|
||||
return [Generation(**generation_dict) for generation_dict in results]
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
f"Could not decode json to list of generations: {generations_json}"
|
||||
)
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
"""Base interface for cache."""
|
||||
|
||||
@ -390,3 +442,179 @@ class GPTCache(BaseCache):
|
||||
gptcache_instance.flush()
|
||||
|
||||
self.gptcache_dict.clear()
|
||||
|
||||
|
||||
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
|
||||
"""Create cache if it doesn't exist.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
"""
|
||||
from momento.responses import CreateCache
|
||||
|
||||
create_cache_response = cache_client.create_cache(cache_name)
|
||||
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
|
||||
create_cache_response, CreateCache.CacheAlreadyExists
|
||||
):
|
||||
return None
|
||||
elif isinstance(create_cache_response, CreateCache.Error):
|
||||
raise create_cache_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
|
||||
|
||||
|
||||
def _validate_ttl(ttl: Optional[timedelta]) -> None:
|
||||
if ttl is not None and ttl <= timedelta(seconds=0):
|
||||
raise ValueError(f"ttl must be positive but was {ttl}.")
|
||||
|
||||
|
||||
class MomentoCache(BaseCache):
|
||||
"""Cache that uses Momento as a backend. See https://gomomento.com/"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_client: momento.CacheClient,
|
||||
cache_name: str,
|
||||
*,
|
||||
ttl: Optional[timedelta] = None,
|
||||
ensure_cache_exists: bool = True,
|
||||
):
|
||||
"""Instantiate a prompt cache using Momento as a backend.
|
||||
|
||||
Note: to instantiate the cache client passed to MomentoCache,
|
||||
you must have a Momento account. See https://gomomento.com/.
|
||||
|
||||
Args:
|
||||
cache_client (CacheClient): The Momento cache client.
|
||||
cache_name (str): The name of the cache to use to store the data.
|
||||
ttl (Optional[timedelta], optional): The time to live for the cache items.
|
||||
Defaults to None, ie use the client default TTL.
|
||||
ensure_cache_exists (bool, optional): Create the cache if it doesn't
|
||||
exist. Defaults to True.
|
||||
|
||||
Raises:
|
||||
ImportError: Momento python package is not installed.
|
||||
TypeError: cache_client is not of type momento.CacheClientObject
|
||||
ValueError: ttl is non-null and non-negative
|
||||
"""
|
||||
try:
|
||||
from momento import CacheClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if not isinstance(cache_client, CacheClient):
|
||||
raise TypeError("cache_client must be a momento.CacheClient object.")
|
||||
_validate_ttl(ttl)
|
||||
if ensure_cache_exists:
|
||||
_ensure_cache_exists(cache_client, cache_name)
|
||||
|
||||
self.cache_client = cache_client
|
||||
self.cache_name = cache_name
|
||||
self.ttl = ttl
|
||||
|
||||
@classmethod
|
||||
def from_client_params(
|
||||
cls,
|
||||
cache_name: str,
|
||||
ttl: timedelta,
|
||||
*,
|
||||
configuration: Optional[momento.config.Configuration] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> MomentoCache:
|
||||
"""Construct cache from CacheClient parameters."""
|
||||
try:
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if configuration is None:
|
||||
configuration = Configurations.Laptop.v1()
|
||||
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
|
||||
credentials = CredentialProvider.from_string(auth_token)
|
||||
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
|
||||
return cls(cache_client, cache_name, ttl=ttl, **kwargs)
|
||||
|
||||
def __key(self, prompt: str, llm_string: str) -> str:
|
||||
"""Compute cache key from prompt and associated model and settings.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt run through the language model.
|
||||
llm_string (str): The language model version and settings.
|
||||
|
||||
Returns:
|
||||
str: The cache key.
|
||||
"""
|
||||
return _hash(prompt + llm_string)
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Lookup llm generations in cache by prompt and associated model and settings.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt run through the language model.
|
||||
llm_string (str): The language model version and settings.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
|
||||
Returns:
|
||||
Optional[RETURN_VAL_TYPE]: A list of language model generations.
|
||||
"""
|
||||
from momento.responses import CacheGet
|
||||
|
||||
generations = []
|
||||
|
||||
get_response = self.cache_client.get(
|
||||
self.cache_name, self.__key(prompt, llm_string)
|
||||
)
|
||||
if isinstance(get_response, CacheGet.Hit):
|
||||
value = get_response.value_string
|
||||
generations = _load_generations_from_json(value)
|
||||
elif isinstance(get_response, CacheGet.Miss):
|
||||
pass
|
||||
elif isinstance(get_response, CacheGet.Error):
|
||||
raise get_response.inner_exception
|
||||
return generations if generations else None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Store llm generations in cache.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt run through the language model.
|
||||
llm_string (str): The language model string.
|
||||
return_val (RETURN_VAL_TYPE): A list of language model generations.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
"""
|
||||
key = self.__key(prompt, llm_string)
|
||||
value = _dump_generations_to_json(return_val)
|
||||
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
|
||||
from momento.responses import CacheSet
|
||||
|
||||
if isinstance(set_response, CacheSet.Success):
|
||||
pass
|
||||
elif isinstance(set_response, CacheSet.Error):
|
||||
raise set_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {set_response}")
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear the cache.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
"""
|
||||
from momento.responses import CacheFlush
|
||||
|
||||
flush_response = self.cache_client.flush_cache(self.cache_name)
|
||||
if isinstance(flush_response, CacheFlush.Success):
|
||||
pass
|
||||
elif isinstance(flush_response, CacheFlush.Error):
|
||||
raise flush_response.inner_exception
|
||||
|
@ -3,6 +3,7 @@ from langchain.memory.buffer import (
|
||||
ConversationStringBufferMemory,
|
||||
)
|
||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.cassandra import (
|
||||
CassandraChatMessageHistory,
|
||||
)
|
||||
@ -50,4 +51,5 @@ __all__ = [
|
||||
"FileChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
]
|
||||
|
@ -7,6 +7,7 @@ from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.firestore import (
|
||||
FirestoreChatMessageHistory,
|
||||
)
|
||||
from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
||||
@ -24,4 +25,5 @@ __all__ = [
|
||||
"MongoDBChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
]
|
||||
|
200
langchain/memory/chat_message_histories/momento.py
Normal file
200
langchain/memory/chat_message_histories/momento.py
Normal file
@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
|
||||
|
||||
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
|
||||
"""Create cache if it doesn't exist.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
"""
|
||||
from momento.responses import CreateCache
|
||||
|
||||
create_cache_response = cache_client.create_cache(cache_name)
|
||||
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
|
||||
create_cache_response, CreateCache.CacheAlreadyExists
|
||||
):
|
||||
return None
|
||||
elif isinstance(create_cache_response, CreateCache.Error):
|
||||
raise create_cache_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
|
||||
|
||||
|
||||
class MomentoChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history cache that uses Momento as a backend.
|
||||
See https://gomomento.com/"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
cache_client: momento.CacheClient,
|
||||
cache_name: str,
|
||||
*,
|
||||
key_prefix: str = "message_store:",
|
||||
ttl: Optional[timedelta] = None,
|
||||
ensure_cache_exists: bool = True,
|
||||
):
|
||||
"""Instantiate a chat message history cache that uses Momento as a backend.
|
||||
|
||||
Note: to instantiate the cache client passed to MomentoChatMessageHistory,
|
||||
you must have a Momento account at https://gomomento.com/.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID to use for this chat session.
|
||||
cache_client (CacheClient): The Momento cache client.
|
||||
cache_name (str): The name of the cache to use to store the messages.
|
||||
key_prefix (str, optional): The prefix to apply to the cache key.
|
||||
Defaults to "message_store:".
|
||||
ttl (Optional[timedelta], optional): The TTL to use for the messages.
|
||||
Defaults to None, ie the default TTL of the cache will be used.
|
||||
ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
|
||||
Defaults to True.
|
||||
|
||||
Raises:
|
||||
ImportError: Momento python package is not installed.
|
||||
TypeError: cache_client is not of type momento.CacheClientObject
|
||||
"""
|
||||
try:
|
||||
from momento import CacheClient
|
||||
from momento.requests import CollectionTtl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if not isinstance(cache_client, CacheClient):
|
||||
raise TypeError("cache_client must be a momento.CacheClient object.")
|
||||
if ensure_cache_exists:
|
||||
_ensure_cache_exists(cache_client, cache_name)
|
||||
self.key = key_prefix + session_id
|
||||
self.cache_client = cache_client
|
||||
self.cache_name = cache_name
|
||||
if ttl is not None:
|
||||
self.ttl = CollectionTtl.of(ttl)
|
||||
else:
|
||||
self.ttl = CollectionTtl.from_cache_ttl()
|
||||
|
||||
@classmethod
|
||||
def from_client_params(
|
||||
cls,
|
||||
session_id: str,
|
||||
cache_name: str,
|
||||
ttl: timedelta,
|
||||
*,
|
||||
configuration: Optional[momento.config.Configuration] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> MomentoChatMessageHistory:
|
||||
"""Construct cache from CacheClient parameters."""
|
||||
try:
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if configuration is None:
|
||||
configuration = Configurations.Laptop.v1()
|
||||
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
|
||||
credentials = CredentialProvider.from_string(auth_token)
|
||||
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
|
||||
return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
|
||||
|
||||
@property
|
||||
def messages(self) -> list[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Momento.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
|
||||
Returns:
|
||||
list[BaseMessage]: List of cached messages
|
||||
"""
|
||||
from momento.responses import CacheListFetch
|
||||
|
||||
fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
|
||||
|
||||
if isinstance(fetch_response, CacheListFetch.Hit):
|
||||
items = [json.loads(m) for m in fetch_response.value_list_string]
|
||||
return messages_from_dict(items)
|
||||
elif isinstance(fetch_response, CacheListFetch.Miss):
|
||||
return []
|
||||
elif isinstance(fetch_response, CacheListFetch.Error):
|
||||
raise fetch_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {fetch_response}")
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
"""Store a user message in the cache.
|
||||
|
||||
Args:
|
||||
message (str): The message to store.
|
||||
"""
|
||||
self.__add_message(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
"""Store an AI message in the cache.
|
||||
|
||||
Args:
|
||||
message (str): The message to store.
|
||||
"""
|
||||
self.__add_message(AIMessage(content=message))
|
||||
|
||||
def __add_message(self, message: BaseMessage) -> None:
|
||||
"""Store a message in the cache.
|
||||
|
||||
Args:
|
||||
message (BaseMessage): The message object to store.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error.
|
||||
Exception: Unexpected response.
|
||||
"""
|
||||
from momento.responses import CacheListPushBack
|
||||
|
||||
item = json.dumps(_message_to_dict(message))
|
||||
push_response = self.cache_client.list_push_back(
|
||||
self.cache_name, self.key, item, ttl=self.ttl
|
||||
)
|
||||
if isinstance(push_response, CacheListPushBack.Success):
|
||||
return None
|
||||
elif isinstance(push_response, CacheListPushBack.Error):
|
||||
raise push_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {push_response}")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove the session's messages from the cache.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error.
|
||||
Exception: Unexpected response.
|
||||
"""
|
||||
from momento.responses import CacheDelete
|
||||
|
||||
delete_response = self.cache_client.delete(self.cache_name, self.key)
|
||||
if isinstance(delete_response, CacheDelete.Success):
|
||||
return None
|
||||
elif isinstance(delete_response, CacheDelete.Error):
|
||||
raise delete_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {delete_response}")
|
41
poetry.lock
generated
41
poetry.lock
generated
@ -2621,7 +2621,7 @@ name = "grpcio"
|
||||
version = "1.47.5"
|
||||
description = "HTTP/2-based RPC framework"
|
||||
category = "main"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "grpcio-1.47.5-cp310-cp310-linux_armv7l.whl", hash = "sha256:acc73289d0c44650aa1f21eccfa967f5623b01c3b5e2b4596fe5f9c5bf10956d"},
|
||||
@ -4462,6 +4462,39 @@ files = [
|
||||
{file = "mmh3-3.1.0.tar.gz", hash = "sha256:9b0f2b2ab4a915333c9d1089572e290a021ebb5b900bb7f7114dccc03995d732"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "momento"
|
||||
version = "1.5.0"
|
||||
description = "SDK for Momento"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "momento-1.5.0-py3-none-any.whl", hash = "sha256:7f633fb26ddf1bfcaf99a7add37b085f0e23c96f85972b655eaaf2de1c61d7f1"},
|
||||
{file = "momento-1.5.0.tar.gz", hash = "sha256:68ca5d24b4cb08c5c0bd22d4edd3b8b0fcf087a85d30673cb2c55b11971c76ec"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
grpcio = ">=1.46.0,<2.0.0"
|
||||
momento-wire-types = ">=0.64,<0.65"
|
||||
pyjwt = ">=2.4.0,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "momento-wire-types"
|
||||
version = "0.64.0"
|
||||
description = "Momento Client Proto Generated Files"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "momento_wire_types-0.64.0-py3-none-any.whl", hash = "sha256:30a5d523cef9209c0863db25e6344044b0c7240fea183a41c1433ca83cefea5b"},
|
||||
{file = "momento_wire_types-0.64.0.tar.gz", hash = "sha256:5d3647210d49d0c3032a74ae5f5cc012a9faf826786272e1436a3d84d70a8bd5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
grpcio = "*"
|
||||
protobuf = ">=3,<5"
|
||||
|
||||
[[package]]
|
||||
name = "monotonic"
|
||||
version = "1.6"
|
||||
@ -6410,7 +6443,7 @@ name = "protobuf"
|
||||
version = "3.19.6"
|
||||
description = "Protocol Buffers"
|
||||
category = "main"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
|
||||
@ -10865,7 +10898,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
@ -10879,4 +10912,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "0f2ed0b37063f533b3403d545b48633ced4d02c15bc8f4e47f1ded1652ab9764"
|
||||
content-hash = "5e83a1f4ca8c0d3107363e393485174fd72ce9db93db5dc7c21b2dd37b184e66"
|
||||
|
@ -97,6 +97,7 @@ scikit-learn = {version = "^1.2.2", optional = true}
|
||||
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
|
||||
azure-ai-vision = {version = "^0.11.1b1", optional = true}
|
||||
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
|
||||
momento = {version = "^1.5.0", optional = true}
|
||||
bibtexparser = {version = "^1.4.0", optional = true}
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
@ -160,6 +161,7 @@ pymongo = "^4.3.3"
|
||||
cassandra-driver = "^3.27.0"
|
||||
arxiv = "^1.4"
|
||||
mastodon-py = "^1.8.1"
|
||||
momento = "^1.5.0"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.0.249"
|
||||
@ -253,6 +255,7 @@ all = [
|
||||
"azure-ai-formrecognizer",
|
||||
"azure-ai-vision",
|
||||
"azure-cognitiveservices-speech",
|
||||
"momento"
|
||||
]
|
||||
|
||||
# An extra used to be able to add extended testing.
|
||||
|
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
@ -0,0 +1,94 @@
|
||||
"""Test Momento cache functionality.
|
||||
|
||||
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||
Momento auth token. This can be obtained by signing up for a free
|
||||
Momento account at https://gomomento.com/.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
|
||||
import langchain
|
||||
from langchain.cache import MomentoCache
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def random_string() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def momento_cache() -> Iterator[MomentoCache]:
|
||||
cache_name = f"langchain-test-cache-{random_string()}"
|
||||
client = CacheClient(
|
||||
Configurations.Laptop.v1(),
|
||||
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||
default_ttl=timedelta(seconds=30),
|
||||
)
|
||||
try:
|
||||
llm_cache = MomentoCache(client, cache_name)
|
||||
langchain.llm_cache = llm_cache
|
||||
yield llm_cache
|
||||
finally:
|
||||
client.delete_cache(cache_name)
|
||||
|
||||
|
||||
def test_invalid_ttl() -> None:
|
||||
client = CacheClient(
|
||||
Configurations.Laptop.v1(),
|
||||
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||
default_ttl=timedelta(seconds=30),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))
|
||||
|
||||
|
||||
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
|
||||
llm = FakeLLM()
|
||||
stub_llm_output = LLMResult(generations=[[Generation(text="foo")]])
|
||||
assert llm.generate([random_string()]) == stub_llm_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompts, generations",
|
||||
[
|
||||
# Single prompt, single generation
|
||||
([random_string()], [[random_string()]]),
|
||||
# Single prompt, multiple generations
|
||||
([random_string()], [[random_string(), random_string()]]),
|
||||
# Single prompt, multiple generations
|
||||
([random_string()], [[random_string(), random_string(), random_string()]]),
|
||||
# Multiple prompts, multiple generations
|
||||
(
|
||||
[random_string(), random_string()],
|
||||
[[random_string()], [random_string(), random_string()]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_momento_cache_hit(
|
||||
momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
|
||||
) -> None:
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
|
||||
llm_generations = [
|
||||
[
|
||||
Generation(text=generation, generation_info=params)
|
||||
for generation in prompt_i_generations
|
||||
]
|
||||
for prompt_i_generations in generations
|
||||
]
|
||||
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
|
||||
momento_cache.update(prompt_i, llm_string, llm_generations_i)
|
||||
|
||||
assert llm.generate(prompts) == LLMResult(
|
||||
generations=llm_generations, llm_output={}
|
||||
)
|
70
tests/integration_tests/memory/test_momento.py
Normal file
70
tests/integration_tests/memory/test_momento.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Test Momento chat message history functionality.
|
||||
|
||||
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||
Momento auth token. This can be obtained by signing up for a free
|
||||
Momento account at https://gomomento.com/.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
|
||||
def random_string() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def message_history() -> Iterator[MomentoChatMessageHistory]:
|
||||
cache_name = f"langchain-test-cache-{random_string()}"
|
||||
client = CacheClient(
|
||||
Configurations.Laptop.v1(),
|
||||
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
|
||||
default_ttl=timedelta(seconds=30),
|
||||
)
|
||||
try:
|
||||
chat_message_history = MomentoChatMessageHistory(
|
||||
session_id="my-test-session",
|
||||
cache_client=client,
|
||||
cache_name=cache_name,
|
||||
)
|
||||
yield chat_message_history
|
||||
finally:
|
||||
client.delete_cache(cache_name)
|
||||
|
||||
|
||||
def test_memory_empty_on_new_session(
|
||||
message_history: MomentoChatMessageHistory,
|
||||
) -> None:
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="foo", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
|
||||
def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# Add some messages to the memory store
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# Verify that the messages are in the store
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# Verify clearing the store
|
||||
memory.chat_memory.clear()
|
||||
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue
Block a user