feat: add Momento as a standard cache and chat message history provider (#5221)

# Add Momento as a standard cache and chat message history provider

This PR adds Momento as a standard caching provider: it implements the cache
interface and includes integration tests and documentation. It also adds
Momento as a chat message history provider, likewise with integration tests
and documentation.

[Momento](https://www.gomomento.com/) is a fully serverless cache. Like S3 or
DynamoDB, it requires zero configuration or infrastructure management and is
instantly available. Users can sign up for free and get 50GB of data in/out
free every month.
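
In short, once an auth token is available, enabling the cache is a one-liner. A minimal sketch, assuming `MOMENTO_AUTH_TOKEN` is set in the environment:

```python
from datetime import timedelta

import langchain
from langchain.cache import MomentoCache

# Build a client from environment credentials and set the global LLM cache
langchain.llm_cache = MomentoCache.from_client_params(
    "langchain", ttl=timedelta(days=1)
)
```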

## Before submitting

We have added documentation, notebooks, and integration tests
demonstrating usage.

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Michael Landis 2023-05-25 19:13:21 -07:00 committed by GitHub
parent 56ad56c812
commit 7047a2c1af
11 changed files with 889 additions and 9 deletions

View File

@ -0,0 +1,53 @@
# Momento
This page covers how to use the [Momento](https://gomomento.com) ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Momento wrappers.
## Installation and Setup
- Sign up for a free account [here](https://docs.momentohq.com/getting-started) and get an auth token
- Install the Momento Python SDK with `pip install momento`
## Wrappers
### Cache
The Cache wrapper allows for [Momento](https://gomomento.com) to be used as a serverless, distributed, low-latency cache for LLM prompts and responses.
#### Standard Cache
The standard cache is the go-to use case for [Momento](https://gomomento.com) users in any environment.
Import the cache as follows:
```python
from langchain.cache import MomentoCache
```
And set up like so:
```python
from datetime import timedelta

from momento import CacheClient, Configurations, CredentialProvider

import langchain

# Instantiate the Momento client
cache_client = CacheClient(
    Configurations.Laptop.v1(),
    CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
    default_ttl=timedelta(days=1),
)

# Choose a name for your Momento cache
cache_name = "langchain"

# Instantiate the LLM cache
langchain.llm_cache = MomentoCache(cache_client, cache_name)
```
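With `langchain.llm_cache` set, repeated identical prompts are served from Momento. A minimal usage sketch, assuming an OpenAI API key is configured (the model name here is illustrative):
```python
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-002")
llm("Tell me a joke")  # first call: cache miss, hits the LLM
llm("Tell me a joke")  # same prompt: served from the Momento cache
```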
### Memory
Momento can be used as a distributed memory store for LLMs.
#### Chat Message History Memory
See [this notebook](../modules/memory/examples/momento_chat_message_history.ipynb) for a walkthrough of how to use Momento as a memory store for chat message history.
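For a quick feel, a minimal sketch of the chat message history (assuming `MOMENTO_AUTH_TOKEN` is set in the environment; the notebook covers the details):
```python
from datetime import timedelta

from langchain.memory import MomentoChatMessageHistory

# Creates the cache "langchain" if it doesn't already exist
history = MomentoChatMessageHistory.from_client_params(
    "my-session", "langchain", timedelta(days=1)
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
```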

View File

@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "91c6a7ef",
"metadata": {},
"source": [
"# Momento\n",
"\n",
"This notebook goes over how to use [Momento Cache](https://gomomento.com) to store chat message history using the `MomentoChatMessageHistory` class. See the Momento [docs](https://docs.momentohq.com/getting-started) for more detail on how to get set up with Momento.\n",
"\n",
"Note that, by default we will create a cache if one with the given name doesn't already exist.\n",
"\n",
"You'll need to get a Momento auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d15e3302",
"metadata": {},
"outputs": [],
"source": [
"from datetime import timedelta\n",
"\n",
"from langchain.memory import MomentoChatMessageHistory\n",
"\n",
"session_id = \"foo\"\n",
"cache_name = \"langchain\"\n",
"ttl = timedelta(days=1),\n",
"history = MomentoChatMessageHistory.from_client_params(\n",
" session_id, \n",
" cache_name,\n",
" ttl,\n",
")\n",
"\n",
"history.add_user_message(\"hi!\")\n",
"\n",
"history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "64fc465e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history.messages"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -41,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"id": "f69f6283",
"metadata": {},
"outputs": [],
@ -612,6 +612,115 @@
"llm(\"Tell me joke\")"
]
},
{
"cell_type": "markdown",
"id": "726fe754",
"metadata": {},
"source": [
"## Momento Cache\n",
"Use [Momento](../../../../integrations/momento.md) to cache prompts and responses.\n",
"\n",
"Requires momento to use, uncomment below to install:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8949f29",
"metadata": {},
"outputs": [],
"source": [
"# !pip install momento"
]
},
{
"cell_type": "markdown",
"id": "56ea6a08",
"metadata": {},
"source": [
"You'll need to get a Momemto auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2005f03a",
"metadata": {},
"outputs": [],
"source": [
"from datetime import timedelta\n",
"\n",
"from langchain.cache import MomentoCache\n",
"\n",
"\n",
"cache_name = \"langchain\"\n",
"ttl = timedelta(days=1)\n",
"langchain.llm_cache = MomentoCache.from_client_params(cache_name, ttl)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c6a6c238",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 40.7 ms, sys: 16.5 ms, total: 57.2 ms\n",
"Wall time: 1.73 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b8f78f9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.16 ms, sys: 2.98 ms, total: 6.14 ms\n",
"Wall time: 57.9 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"# When run in the same region as the cache, latencies are single digit ms\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "934943dc",
@ -909,9 +1018,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "venv",
"language": "python",
"name": "python3"
"name": "venv"
},
"language_info": {
"codemirror_mode": {
@ -923,7 +1032,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.11.3"
}
},
"nbformat": 4,

View File

@ -1,14 +1,30 @@
"""Beta Feature: base interface for cache."""
from __future__ import annotations
import hashlib
import inspect
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
from langchain.utils import get_from_env
try:
from sqlalchemy.orm import declarative_base
except ImportError:
@ -18,6 +34,9 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
if TYPE_CHECKING:
import momento
RETURN_VAL_TYPE = List[Generation]
@ -26,6 +45,39 @@ def _hash(_input: str) -> str:
return hashlib.md5(_input.encode()).hexdigest()
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
"""Dump generations to json.
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: Json representing a list of generations.
"""
return json.dumps([generation.dict() for generation in generations])
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
"""Load generations from json.
Args:
generations_json (str): A string of json representing a list of generations.
Raises:
ValueError: Could not decode json string to list of generations.
Returns:
RETURN_VAL_TYPE: A list of generations.
"""
try:
results = json.loads(generations_json)
return [Generation(**generation_dict) for generation_dict in results]
except json.JSONDecodeError:
raise ValueError(
f"Could not decode json to list of generations: {generations_json}"
)
class BaseCache(ABC):
"""Base interface for cache."""
@ -390,3 +442,179 @@ class GPTCache(BaseCache):
gptcache_instance.flush()
self.gptcache_dict.clear()
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
"""Create cache if it doesn't exist.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
"""
from momento.responses import CreateCache
create_cache_response = cache_client.create_cache(cache_name)
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
create_cache_response, CreateCache.CacheAlreadyExists
):
return None
elif isinstance(create_cache_response, CreateCache.Error):
raise create_cache_response.inner_exception
else:
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
def _validate_ttl(ttl: Optional[timedelta]) -> None:
if ttl is not None and ttl <= timedelta(seconds=0):
raise ValueError(f"ttl must be positive but was {ttl}.")
class MomentoCache(BaseCache):
"""Cache that uses Momento as a backend. See https://gomomento.com/"""
def __init__(
self,
cache_client: momento.CacheClient,
cache_name: str,
*,
ttl: Optional[timedelta] = None,
ensure_cache_exists: bool = True,
):
"""Instantiate a prompt cache using Momento as a backend.
Note: to instantiate the cache client passed to MomentoCache,
you must have a Momento account. See https://gomomento.com/.
Args:
cache_client (CacheClient): The Momento cache client.
cache_name (str): The name of the cache to use to store the data.
ttl (Optional[timedelta], optional): The time to live for the cache items.
Defaults to None, i.e. use the client's default TTL.
ensure_cache_exists (bool, optional): Create the cache if it doesn't
exist. Defaults to True.
Raises:
ImportError: Momento python package is not installed.
TypeError: cache_client is not of type momento.CacheClient
ValueError: ttl is non-null and non-positive
"""
try:
from momento import CacheClient
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if not isinstance(cache_client, CacheClient):
raise TypeError("cache_client must be a momento.CacheClient object.")
_validate_ttl(ttl)
if ensure_cache_exists:
_ensure_cache_exists(cache_client, cache_name)
self.cache_client = cache_client
self.cache_name = cache_name
self.ttl = ttl
@classmethod
def from_client_params(
cls,
cache_name: str,
ttl: timedelta,
*,
configuration: Optional[momento.config.Configuration] = None,
auth_token: Optional[str] = None,
**kwargs: Any,
) -> MomentoCache:
"""Construct cache from CacheClient parameters."""
try:
from momento import CacheClient, Configurations, CredentialProvider
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if configuration is None:
configuration = Configurations.Laptop.v1()
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
credentials = CredentialProvider.from_string(auth_token)
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
return cls(cache_client, cache_name, ttl=ttl, **kwargs)
def __key(self, prompt: str, llm_string: str) -> str:
"""Compute cache key from prompt and associated model and settings.
Args:
prompt (str): The prompt run through the language model.
llm_string (str): The language model version and settings.
Returns:
str: The cache key.
"""
return _hash(prompt + llm_string)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Lookup llm generations in cache by prompt and associated model and settings.
Args:
prompt (str): The prompt run through the language model.
llm_string (str): The language model version and settings.
Raises:
SdkException: Momento service or network error
Returns:
Optional[RETURN_VAL_TYPE]: A list of language model generations.
"""
from momento.responses import CacheGet
generations = []
get_response = self.cache_client.get(
self.cache_name, self.__key(prompt, llm_string)
)
if isinstance(get_response, CacheGet.Hit):
value = get_response.value_string
generations = _load_generations_from_json(value)
elif isinstance(get_response, CacheGet.Miss):
pass
elif isinstance(get_response, CacheGet.Error):
raise get_response.inner_exception
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Store llm generations in cache.
Args:
prompt (str): The prompt run through the language model.
llm_string (str): The language model string.
return_val (RETURN_VAL_TYPE): A list of language model generations.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
"""
key = self.__key(prompt, llm_string)
value = _dump_generations_to_json(return_val)
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
from momento.responses import CacheSet
if isinstance(set_response, CacheSet.Success):
pass
elif isinstance(set_response, CacheSet.Error):
raise set_response.inner_exception
else:
raise Exception(f"Unexpected response: {set_response}")
def clear(self, **kwargs: Any) -> None:
"""Clear the cache.
Raises:
SdkException: Momento service or network error
"""
from momento.responses import CacheFlush
flush_response = self.cache_client.flush_cache(self.cache_name)
if isinstance(flush_response, CacheFlush.Success):
pass
elif isinstance(flush_response, CacheFlush.Error):
raise flush_response.inner_exception

View File

@ -3,6 +3,7 @@ from langchain.memory.buffer import (
ConversationStringBufferMemory,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
@ -50,4 +51,5 @@ __all__ = [
"FileChatMessageHistory",
"MongoDBChatMessageHistory",
"CassandraChatMessageHistory",
"MomentoChatMessageHistory",
]

View File

@ -7,6 +7,7 @@ from langchain.memory.chat_message_histories.file import FileChatMessageHistory
from langchain.memory.chat_message_histories.firestore import (
FirestoreChatMessageHistory,
)
from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
@ -24,4 +25,5 @@ __all__ = [
"MongoDBChatMessageHistory",
"CassandraChatMessageHistory",
"ZepChatMessageHistory",
"MomentoChatMessageHistory",
]

View File

@ -0,0 +1,200 @@
from __future__ import annotations
import json
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
from langchain.utils import get_from_env
if TYPE_CHECKING:
import momento
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
"""Create cache if it doesn't exist.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
"""
from momento.responses import CreateCache
create_cache_response = cache_client.create_cache(cache_name)
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
create_cache_response, CreateCache.CacheAlreadyExists
):
return None
elif isinstance(create_cache_response, CreateCache.Error):
raise create_cache_response.inner_exception
else:
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
class MomentoChatMessageHistory(BaseChatMessageHistory):
"""Chat message history cache that uses Momento as a backend.
See https://gomomento.com/"""
def __init__(
self,
session_id: str,
cache_client: momento.CacheClient,
cache_name: str,
*,
key_prefix: str = "message_store:",
ttl: Optional[timedelta] = None,
ensure_cache_exists: bool = True,
):
"""Instantiate a chat message history cache that uses Momento as a backend.
Note: to instantiate the cache client passed to MomentoChatMessageHistory,
you must have a Momento account at https://gomomento.com/.
Args:
session_id (str): The session ID to use for this chat session.
cache_client (CacheClient): The Momento cache client.
cache_name (str): The name of the cache to use to store the messages.
key_prefix (str, optional): The prefix to apply to the cache key.
Defaults to "message_store:".
ttl (Optional[timedelta], optional): The TTL to use for the messages.
Defaults to None, i.e. the default TTL of the cache will be used.
ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
Defaults to True.
Raises:
ImportError: Momento python package is not installed.
TypeError: cache_client is not of type momento.CacheClient
"""
try:
from momento import CacheClient
from momento.requests import CollectionTtl
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if not isinstance(cache_client, CacheClient):
raise TypeError("cache_client must be a momento.CacheClient object.")
if ensure_cache_exists:
_ensure_cache_exists(cache_client, cache_name)
self.key = key_prefix + session_id
self.cache_client = cache_client
self.cache_name = cache_name
if ttl is not None:
self.ttl = CollectionTtl.of(ttl)
else:
self.ttl = CollectionTtl.from_cache_ttl()
@classmethod
def from_client_params(
cls,
session_id: str,
cache_name: str,
ttl: timedelta,
*,
configuration: Optional[momento.config.Configuration] = None,
auth_token: Optional[str] = None,
**kwargs: Any,
) -> MomentoChatMessageHistory:
"""Construct cache from CacheClient parameters."""
try:
from momento import CacheClient, Configurations, CredentialProvider
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if configuration is None:
configuration = Configurations.Laptop.v1()
auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
credentials = CredentialProvider.from_string(auth_token)
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
@property
def messages(self) -> list[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Momento.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
Returns:
list[BaseMessage]: List of cached messages
"""
from momento.responses import CacheListFetch
fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
if isinstance(fetch_response, CacheListFetch.Hit):
items = [json.loads(m) for m in fetch_response.value_list_string]
return messages_from_dict(items)
elif isinstance(fetch_response, CacheListFetch.Miss):
return []
elif isinstance(fetch_response, CacheListFetch.Error):
raise fetch_response.inner_exception
else:
raise Exception(f"Unexpected response: {fetch_response}")
def add_user_message(self, message: str) -> None:
"""Store a user message in the cache.
Args:
message (str): The message to store.
"""
self.__add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Store an AI message in the cache.
Args:
message (str): The message to store.
"""
self.__add_message(AIMessage(content=message))
def __add_message(self, message: BaseMessage) -> None:
"""Store a message in the cache.
Args:
message (BaseMessage): The message object to store.
Raises:
SdkException: Momento service or network error.
Exception: Unexpected response.
"""
from momento.responses import CacheListPushBack
item = json.dumps(_message_to_dict(message))
push_response = self.cache_client.list_push_back(
self.cache_name, self.key, item, ttl=self.ttl
)
if isinstance(push_response, CacheListPushBack.Success):
return None
elif isinstance(push_response, CacheListPushBack.Error):
raise push_response.inner_exception
else:
raise Exception(f"Unexpected response: {push_response}")
def clear(self) -> None:
"""Remove the session's messages from the cache.
Raises:
SdkException: Momento service or network error.
Exception: Unexpected response.
"""
from momento.responses import CacheDelete
delete_response = self.cache_client.delete(self.cache_name, self.key)
if isinstance(delete_response, CacheDelete.Success):
return None
elif isinstance(delete_response, CacheDelete.Error):
raise delete_response.inner_exception
else:
raise Exception(f"Unexpected response: {delete_response}")

poetry.lock (generated)
View File

@ -2621,7 +2621,7 @@ name = "grpcio"
version = "1.47.5"
description = "HTTP/2-based RPC framework"
category = "main"
optional = true
optional = false
python-versions = ">=3.6"
files = [
{file = "grpcio-1.47.5-cp310-cp310-linux_armv7l.whl", hash = "sha256:acc73289d0c44650aa1f21eccfa967f5623b01c3b5e2b4596fe5f9c5bf10956d"},
@ -4462,6 +4462,39 @@ files = [
{file = "mmh3-3.1.0.tar.gz", hash = "sha256:9b0f2b2ab4a915333c9d1089572e290a021ebb5b900bb7f7114dccc03995d732"},
]
[[package]]
name = "momento"
version = "1.5.0"
description = "SDK for Momento"
category = "main"
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "momento-1.5.0-py3-none-any.whl", hash = "sha256:7f633fb26ddf1bfcaf99a7add37b085f0e23c96f85972b655eaaf2de1c61d7f1"},
{file = "momento-1.5.0.tar.gz", hash = "sha256:68ca5d24b4cb08c5c0bd22d4edd3b8b0fcf087a85d30673cb2c55b11971c76ec"},
]
[package.dependencies]
grpcio = ">=1.46.0,<2.0.0"
momento-wire-types = ">=0.64,<0.65"
pyjwt = ">=2.4.0,<3.0.0"
[[package]]
name = "momento-wire-types"
version = "0.64.0"
description = "Momento Client Proto Generated Files"
category = "main"
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "momento_wire_types-0.64.0-py3-none-any.whl", hash = "sha256:30a5d523cef9209c0863db25e6344044b0c7240fea183a41c1433ca83cefea5b"},
{file = "momento_wire_types-0.64.0.tar.gz", hash = "sha256:5d3647210d49d0c3032a74ae5f5cc012a9faf826786272e1436a3d84d70a8bd5"},
]
[package.dependencies]
grpcio = "*"
protobuf = ">=3,<5"
[[package]]
name = "monotonic"
version = "1.6"
@ -6410,7 +6443,7 @@ name = "protobuf"
version = "3.19.6"
description = "Protocol Buffers"
category = "main"
optional = true
optional = false
python-versions = ">=3.5"
files = [
{file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
@ -10865,7 +10898,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"]
cohere = ["cohere"]
docarray = ["docarray"]
@ -10879,4 +10912,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "0f2ed0b37063f533b3403d545b48633ced4d02c15bc8f4e47f1ded1652ab9764"
content-hash = "5e83a1f4ca8c0d3107363e393485174fd72ce9db93db5dc7c21b2dd37b184e66"

View File

@ -97,6 +97,7 @@ scikit-learn = {version = "^1.2.2", optional = true}
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
azure-ai-vision = {version = "^0.11.1b1", optional = true}
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
momento = {version = "^1.5.0", optional = true}
bibtexparser = {version = "^1.4.0", optional = true}
[tool.poetry.group.docs.dependencies]
@ -160,6 +161,7 @@ pymongo = "^4.3.3"
cassandra-driver = "^3.27.0"
arxiv = "^1.4"
mastodon-py = "^1.8.1"
momento = "^1.5.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.0.249"
@ -253,6 +255,7 @@ all = [
"azure-ai-formrecognizer",
"azure-ai-vision",
"azure-cognitiveservices-speech",
"momento"
]
# An extra used to be able to add extended testing.

View File

@ -0,0 +1,94 @@
"""Test Momento cache functionality.
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
from __future__ import annotations
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from momento import CacheClient, Configurations, CredentialProvider
import langchain
from langchain.cache import MomentoCache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
def random_string() -> str:
return str(uuid.uuid4())
@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
cache_name = f"langchain-test-cache-{random_string()}"
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
default_ttl=timedelta(seconds=30),
)
try:
llm_cache = MomentoCache(client, cache_name)
langchain.llm_cache = llm_cache
yield llm_cache
finally:
client.delete_cache(cache_name)
def test_invalid_ttl() -> None:
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
default_ttl=timedelta(seconds=30),
)
with pytest.raises(ValueError):
MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
llm = FakeLLM()
stub_llm_output = LLMResult(generations=[[Generation(text="foo")]])
assert llm.generate([random_string()]) == stub_llm_output
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
)
def test_momento_cache_hit(
momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
momento_cache.update(prompt_i, llm_string, llm_generations_i)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)

View File

@ -0,0 +1,70 @@
"""Test Momento chat message history functionality.
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
import json
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from momento import CacheClient, Configurations, CredentialProvider
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
from langchain.schema import _message_to_dict
def random_string() -> str:
return str(uuid.uuid4())
@pytest.fixture(scope="function")
def message_history() -> Iterator[MomentoChatMessageHistory]:
cache_name = f"langchain-test-cache-{random_string()}"
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
default_ttl=timedelta(seconds=30),
)
try:
chat_message_history = MomentoChatMessageHistory(
session_id="my-test-session",
cache_client=client,
cache_name=cache_name,
)
yield chat_message_history
finally:
client.delete_cache(cache_name)
def test_memory_empty_on_new_session(
message_history: MomentoChatMessageHistory,
) -> None:
memory = ConversationBufferMemory(
memory_key="foo", chat_memory=message_history, return_messages=True
)
assert memory.chat_memory.messages == []
def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# Add some messages to the memory store
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# Verify that the messages are in the store
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# Verify clearing the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []