From 7047a2c1afce1f1e2e6e4e3e9d94bbf369466a5f Mon Sep 17 00:00:00 2001
From: Michael Landis
Date: Thu, 25 May 2023 19:13:21 -0700
Subject: [PATCH] feat: add Momento as a standard cache and chat message history provider (#5221)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Add Momento as a standard cache and chat message history provider

This PR adds Momento as a standard caching provider: it implements the cache interface and adds integration tests and documentation. We also add Momento as a chat message history provider, again with integration tests and documentation.

[Momento](https://www.gomomento.com/) is a fully serverless cache. Similar to S3 or DynamoDB, it requires zero configuration or infrastructure management and is instantly available. Users sign up for free and get 50GB of data in/out per month at no charge.

## Before submitting

✅ We have added documentation, notebooks, and integration tests demonstrating usage.

Co-authored-by: Dev 2049
---
 docs/integrations/momento.md                  |  68 +++++
 .../momento_chat_message_history.ipynb        |  86 +++++++
 .../models/llms/examples/llm_caching.ipynb    | 117 ++++++++-
 langchain/cache.py                            | 230 +++++++++++++++++-
 langchain/memory/__init__.py                  |   2 +
 .../memory/chat_message_histories/__init__.py |   2 +
 .../memory/chat_message_histories/momento.py  | 200 +++++++++++++++
 poetry.lock                                   |  41 +++-
 pyproject.toml                                |   3 +
 .../cache/test_momento_cache.py               |  94 +++++++
 .../integration_tests/memory/test_momento.py  |  70 ++++++
 11 files changed, 904 insertions(+), 9 deletions(-)
 create mode 100644 docs/integrations/momento.md
 create mode 100644 docs/modules/memory/examples/momento_chat_message_history.ipynb
 create mode 100644 langchain/memory/chat_message_histories/momento.py
 create mode 100644 tests/integration_tests/cache/test_momento_cache.py
 create mode 100644 tests/integration_tests/memory/test_momento.py

diff --git a/docs/integrations/momento.md b/docs/integrations/momento.md
new file mode 100644
index 00000000..41c29a05
--- /dev/null
+++ b/docs/integrations/momento.md
@@ -0,0 +1,68 @@
+# Momento
+
+This page covers how to use the [Momento](https://gomomento.com) ecosystem within LangChain.
+It is broken into two parts: installation and setup, and then references to specific Momento wrappers.
+
+## Installation and Setup
+
+- Sign up for a free account [here](https://docs.momentohq.com/getting-started) and get an auth token
+- Install the Momento Python SDK with `pip install momento`
+
+## Wrappers
+
+### Cache
+
+The Cache wrapper allows [Momento](https://gomomento.com) to be used as a serverless, distributed, low-latency cache for LLM prompts and responses.
+
+#### Standard Cache
+
+The standard cache is the go-to use case for [Momento](https://gomomento.com) users in any environment.
+
+Import the cache as follows:
+
+```python
+from langchain.cache import MomentoCache
+```
+
+And set it up like so:
+
+```python
+from datetime import timedelta
+from momento import CacheClient, Configurations, CredentialProvider
+import langchain
+
+# Instantiate the Momento client
+cache_client = CacheClient(
+    Configurations.Laptop.v1(),
+    CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
+    default_ttl=timedelta(days=1))
+
+# Choose a name for your Momento cache
+cache_name = "langchain"
+
+# Instantiate the LLM cache
+langchain.llm_cache = MomentoCache(cache_client, cache_name)
+```
+
+### Memory
+
+Momento can be used as a distributed memory store for LLMs.
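+
+For example, a minimal sketch of wiring up chat history (this assumes the `MOMENTO_AUTH_TOKEN` environment variable is set; the session ID and cache name below are placeholders):
+
+```python
+from datetime import timedelta
+
+from langchain.memory import MomentoChatMessageHistory
+
+history = MomentoChatMessageHistory.from_client_params(
+    "my-session",       # session ID for this conversation
+    "langchain",        # cache name; created if it doesn't already exist
+    timedelta(days=1),  # TTL for the stored messages
+)
+history.add_user_message("hi!")
+```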
+
+#### Chat Message History Memory
+
+See [this notebook](../modules/memory/examples/momento_chat_message_history.ipynb) for a walkthrough of how to use Momento as a memory store for chat message history.
diff --git a/docs/modules/memory/examples/momento_chat_message_history.ipynb b/docs/modules/memory/examples/momento_chat_message_history.ipynb
new file mode 100644
index 00000000..85e6fa59
--- /dev/null
+++ b/docs/modules/memory/examples/momento_chat_message_history.ipynb
@@ -0,0 +1,86 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "91c6a7ef",
+   "metadata": {},
+   "source": [
+    "# Momento\n",
+    "\n",
+    "This notebook goes over how to use [Momento Cache](https://gomomento.com) to store chat message history using the `MomentoChatMessageHistory` class. See the Momento [docs](https://docs.momentohq.com/getting-started) for more detail on how to get set up with Momento.\n",
+    "\n",
+    "Note that by default we will create a cache if one with the given name doesn't already exist.\n",
+    "\n",
+    "You'll need a Momento auth token to use this class. It can be passed in to a momento.CacheClient if you'd like to instantiate that directly, passed as the named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or set as the environment variable `MOMENTO_AUTH_TOKEN`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "d15e3302",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import timedelta\n",
+    "\n",
+    "from langchain.memory import MomentoChatMessageHistory\n",
+    "\n",
+    "session_id = \"foo\"\n",
+    "cache_name = \"langchain\"\n",
+    "ttl = timedelta(days=1)\n",
+    "history = MomentoChatMessageHistory.from_client_params(\n",
+    "    session_id,\n",
+    "    cache_name,\n",
+    "    ttl,\n",
+    ")\n",
+    "\n",
+    "history.add_user_message(\"hi!\")\n",
+    "\n",
+    "history.add_ai_message(\"whats up?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "64fc465e",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
+       " AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "history.messages"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/modules/models/llms/examples/llm_caching.ipynb b/docs/modules/models/llms/examples/llm_caching.ipynb
index a049422c..149d7969 100644
--- a/docs/modules/models/llms/examples/llm_caching.ipynb
+++ b/docs/modules/models/llms/examples/llm_caching.ipynb
@@ -41,7 +41,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 6,
    "id": "f69f6283",
    "metadata": {},
    "outputs": [],
@@ -612,6 +612,115 @@
    "llm(\"Tell me joke\")"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "726fe754",
+  "metadata": {},
+  "source": [
+   "## Momento Cache\n",
+   "Use [Momento](../../../../integrations/momento.md) to cache prompts and responses.\n",
+   "\n",
+   "The momento package is required to use this cache. Uncomment the cell below to install it:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
"e8949f29", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install momento" + ] + }, + { + "cell_type": "markdown", + "id": "56ea6a08", + "metadata": {}, + "source": [ + "You'll need to get a Momemto auth token to use this class. This can either be passed in to a momento.CacheClient if you'd like to instantiate that directly, as a named parameter `auth_token` to `MomentoChatMessageHistory.from_client_params`, or can just be set as an environment variable `MOMENTO_AUTH_TOKEN`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2005f03a", + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import timedelta\n", + "\n", + "from langchain.cache import MomentoCache\n", + "\n", + "\n", + "cache_name = \"langchain\"\n", + "ttl = timedelta(days=1)\n", + "langchain.llm_cache = MomentoCache.from_client_params(cache_name, ttl)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c6a6c238", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 40.7 ms, sys: 16.5 ms, total: 57.2 ms\n", + "Wall time: 1.73 s\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b8f78f9d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.16 ms, sys: 2.98 ms, total: 6.14 ms\n", + "Wall time: 57.9 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The second time it is, so it goes faster\n", + "# When run in the same region as the cache, latencies are single digit ms\n", + "llm(\"Tell me a joke\")" + ] + }, { "cell_type": "markdown", "id": "934943dc", @@ -909,9 +1018,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "venv", "language": "python", - "name": "python3" + "name": "venv" }, "language_info": { "codemirror_mode": { @@ -923,7 +1032,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/langchain/cache.py b/langchain/cache.py index 5b2cf2c0..857e3dc5 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -1,14 +1,30 @@ """Beta Feature: base interface for cache.""" +from __future__ import annotations + import hashlib import inspect import json from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session +from langchain.utils import get_from_env + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -18,6 +34,9 @@ from langchain.embeddings.base import Embeddings from langchain.schema import Generation from langchain.vectorstores.redis import Redis as 
 
+if TYPE_CHECKING:
+    import momento
+
 RETURN_VAL_TYPE = List[Generation]
 
@@ -26,6 +45,39 @@ def _hash(_input: str) -> str:
     return hashlib.md5(_input.encode()).hexdigest()
 
 
+def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
+    """Dump generations to json.
+
+    Args:
+        generations (RETURN_VAL_TYPE): A list of language model generations.
+
+    Returns:
+        str: Json representing a list of generations.
+    """
+    return json.dumps([generation.dict() for generation in generations])
+
+
+def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
+    """Load generations from json.
+
+    Args:
+        generations_json (str): A string of json representing a list of generations.
+
+    Raises:
+        ValueError: Could not decode json string to list of generations.
+
+    Returns:
+        RETURN_VAL_TYPE: A list of generations.
+    """
+    try:
+        results = json.loads(generations_json)
+        return [Generation(**generation_dict) for generation_dict in results]
+    except json.JSONDecodeError:
+        raise ValueError(
+            f"Could not decode json to list of generations: {generations_json}"
+        )
+
+
 class BaseCache(ABC):
     """Base interface for cache."""
 
@@ -390,3 +442,179 @@ class GPTCache(BaseCache):
             gptcache_instance.flush()
 
         self.gptcache_dict.clear()
+
+
+def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
+    """Create cache if it doesn't exist.
+
+    Raises:
+        SdkException: Momento service or network error
+        Exception: Unexpected response
+    """
+    from momento.responses import CreateCache
+
+    create_cache_response = cache_client.create_cache(cache_name)
+    if isinstance(create_cache_response, CreateCache.Success) or isinstance(
+        create_cache_response, CreateCache.CacheAlreadyExists
+    ):
+        return None
+    elif isinstance(create_cache_response, CreateCache.Error):
+        raise create_cache_response.inner_exception
+    else:
+        raise Exception(f"Unexpected response during cache creation: {create_cache_response}")
+
+
+def _validate_ttl(ttl: Optional[timedelta]) -> None:
+    if ttl is not None and ttl <= timedelta(seconds=0):
+        raise ValueError(f"ttl must be positive but was {ttl}.")
+
+
+class MomentoCache(BaseCache):
+    """Cache that uses Momento as a backend. See https://gomomento.com/"""
+
+    def __init__(
+        self,
+        cache_client: momento.CacheClient,
+        cache_name: str,
+        *,
+        ttl: Optional[timedelta] = None,
+        ensure_cache_exists: bool = True,
+    ):
+        """Instantiate a prompt cache using Momento as a backend.
+
+        Note: to instantiate the cache client passed to MomentoCache,
+        you must have a Momento account. See https://gomomento.com/.
+
+        Args:
+            cache_client (CacheClient): The Momento cache client.
+            cache_name (str): The name of the cache to use to store the data.
+            ttl (Optional[timedelta], optional): The time to live for the cache items.
+                Defaults to None, i.e. use the client default TTL.
+            ensure_cache_exists (bool, optional): Create the cache if it doesn't
+                exist. Defaults to True.
+
+        Raises:
+            ImportError: Momento python package is not installed.
+            TypeError: cache_client is not of type momento.CacheClient
+            ValueError: ttl is non-null and non-positive
+        """
+        try:
+            from momento import CacheClient
+        except ImportError:
+            raise ImportError(
+                "Could not import momento python package. "
+                "Please install it with `pip install momento`."
+            )
+        if not isinstance(cache_client, CacheClient):
+            raise TypeError("cache_client must be a momento.CacheClient object.")
+        _validate_ttl(ttl)
+        if ensure_cache_exists:
+            _ensure_cache_exists(cache_client, cache_name)
+
+        self.cache_client = cache_client
+        self.cache_name = cache_name
+        self.ttl = ttl
+
+    @classmethod
+    def from_client_params(
+        cls,
+        cache_name: str,
+        ttl: timedelta,
+        *,
+        configuration: Optional[momento.config.Configuration] = None,
+        auth_token: Optional[str] = None,
+        **kwargs: Any,
+    ) -> MomentoCache:
+        """Construct cache from CacheClient parameters."""
+        try:
+            from momento import CacheClient, Configurations, CredentialProvider
+        except ImportError:
+            raise ImportError(
+                "Could not import momento python package. "
+                "Please install it with `pip install momento`."
+            )
+        if configuration is None:
+            configuration = Configurations.Laptop.v1()
+        auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
+        credentials = CredentialProvider.from_string(auth_token)
+        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
+        return cls(cache_client, cache_name, ttl=ttl, **kwargs)
+
+    def __key(self, prompt: str, llm_string: str) -> str:
+        """Compute cache key from prompt and associated model and settings.
+
+        Args:
+            prompt (str): The prompt run through the language model.
+            llm_string (str): The language model version and settings.
+
+        Returns:
+            str: The cache key.
+        """
+        return _hash(prompt + llm_string)
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Lookup llm generations in cache by prompt and associated model and settings.
+
+        Args:
+            prompt (str): The prompt run through the language model.
+            llm_string (str): The language model version and settings.
+
+        Raises:
+            SdkException: Momento service or network error
+
+        Returns:
+            Optional[RETURN_VAL_TYPE]: A list of language model generations.
+        """
+        from momento.responses import CacheGet
+
+        generations: RETURN_VAL_TYPE = []
+
+        get_response = self.cache_client.get(
+            self.cache_name, self.__key(prompt, llm_string)
+        )
+        if isinstance(get_response, CacheGet.Hit):
+            value = get_response.value_string
+            generations = _load_generations_from_json(value)
+        elif isinstance(get_response, CacheGet.Miss):
+            pass
+        elif isinstance(get_response, CacheGet.Error):
+            raise get_response.inner_exception
+        return generations if generations else None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Store llm generations in cache.
+
+        Args:
+            prompt (str): The prompt run through the language model.
+            llm_string (str): The language model string.
+            return_val (RETURN_VAL_TYPE): A list of language model generations.
+
+        Raises:
+            SdkException: Momento service or network error
+            Exception: Unexpected response
+        """
+        from momento.responses import CacheSet
+
+        key = self.__key(prompt, llm_string)
+        value = _dump_generations_to_json(return_val)
+        set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
+        if isinstance(set_response, CacheSet.Success):
+            pass
+        elif isinstance(set_response, CacheSet.Error):
+            raise set_response.inner_exception
+        else:
+            raise Exception(f"Unexpected response: {set_response}")
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear the cache.
+
+        Raises:
+            SdkException: Momento service or network error
+        """
+        from momento.responses import CacheFlush
+
+        flush_response = self.cache_client.flush_cache(self.cache_name)
+        if isinstance(flush_response, CacheFlush.Success):
+            pass
+        elif isinstance(flush_response, CacheFlush.Error):
+            raise flush_response.inner_exception
diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py
index ee491e34..e11316e8 100644
--- a/langchain/memory/__init__.py
+++ b/langchain/memory/__init__.py
@@ -3,6 +3,7 @@ from langchain.memory.buffer import (
     ConversationStringBufferMemory,
 )
 from langchain.memory.buffer_window import ConversationBufferWindowMemory
+from langchain.memory.chat_message_histories import MomentoChatMessageHistory
 from langchain.memory.chat_message_histories.cassandra import (
     CassandraChatMessageHistory,
 )
@@ -50,4 +51,5 @@ __all__ = [
     "FileChatMessageHistory",
     "MongoDBChatMessageHistory",
     "CassandraChatMessageHistory",
+    "MomentoChatMessageHistory",
 ]
diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py
index a5341ea5..a6a0b125 100644
--- a/langchain/memory/chat_message_histories/__init__.py
+++ b/langchain/memory/chat_message_histories/__init__.py
@@ -7,6 +7,7 @@ from langchain.memory.chat_message_histories.file import FileChatMessageHistory
 from langchain.memory.chat_message_histories.firestore import (
     FirestoreChatMessageHistory,
 )
+from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory
 from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
 from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
 from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
@@ -24,4 +25,5 @@ __all__ = [
     "MongoDBChatMessageHistory",
     "CassandraChatMessageHistory",
     "ZepChatMessageHistory",
+    "MomentoChatMessageHistory",
 ]
diff --git a/langchain/memory/chat_message_histories/momento.py b/langchain/memory/chat_message_histories/momento.py
new file mode 100644
index 00000000..1bc74981
--- /dev/null
+++ b/langchain/memory/chat_message_histories/momento.py
@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+import json
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Optional
+
+from langchain.schema import (
+    AIMessage,
+    BaseChatMessageHistory,
+    BaseMessage,
+    HumanMessage,
+    _message_to_dict,
+    messages_from_dict,
+)
+from langchain.utils import get_from_env
+
+if TYPE_CHECKING:
+    import momento
+
+
+def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
+    """Create cache if it doesn't exist.
+
+    Raises:
+        SdkException: Momento service or network error
+        Exception: Unexpected response
+    """
+    from momento.responses import CreateCache
+
+    create_cache_response = cache_client.create_cache(cache_name)
+    if isinstance(create_cache_response, CreateCache.Success) or isinstance(
+        create_cache_response, CreateCache.CacheAlreadyExists
+    ):
+        return None
+    elif isinstance(create_cache_response, CreateCache.Error):
+        raise create_cache_response.inner_exception
+    else:
+        raise Exception(f"Unexpected response during cache creation: {create_cache_response}")
+
+
+class MomentoChatMessageHistory(BaseChatMessageHistory):
+    """Chat message history cache that uses Momento as a backend.
+    See https://gomomento.com/"""
+
+    def __init__(
+        self,
+        session_id: str,
+        cache_client: momento.CacheClient,
+        cache_name: str,
+        *,
+        key_prefix: str = "message_store:",
+        ttl: Optional[timedelta] = None,
+        ensure_cache_exists: bool = True,
+    ):
+        """Instantiate a chat message history cache that uses Momento as a backend.
+
+        Note: to instantiate the cache client passed to MomentoChatMessageHistory,
+        you must have a Momento account at https://gomomento.com/.
+
+        Args:
+            session_id (str): The session ID to use for this chat session.
+            cache_client (CacheClient): The Momento cache client.
+            cache_name (str): The name of the cache to use to store the messages.
+            key_prefix (str, optional): The prefix to apply to the cache key.
+                Defaults to "message_store:".
+            ttl (Optional[timedelta], optional): The TTL to use for the messages.
+                Defaults to None, i.e. the default TTL of the cache will be used.
+            ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
+                Defaults to True.
+
+        Raises:
+            ImportError: Momento python package is not installed.
+            TypeError: cache_client is not of type momento.CacheClient
+        """
+        try:
+            from momento import CacheClient
+            from momento.requests import CollectionTtl
+        except ImportError:
+            raise ImportError(
+                "Could not import momento python package. "
+                "Please install it with `pip install momento`."
+            )
+        if not isinstance(cache_client, CacheClient):
+            raise TypeError("cache_client must be a momento.CacheClient object.")
+        if ensure_cache_exists:
+            _ensure_cache_exists(cache_client, cache_name)
+        self.key = key_prefix + session_id
+        self.cache_client = cache_client
+        self.cache_name = cache_name
+        if ttl is not None:
+            self.ttl = CollectionTtl.of(ttl)
+        else:
+            self.ttl = CollectionTtl.from_cache_ttl()
+
+    @classmethod
+    def from_client_params(
+        cls,
+        session_id: str,
+        cache_name: str,
+        ttl: timedelta,
+        *,
+        configuration: Optional[momento.config.Configuration] = None,
+        auth_token: Optional[str] = None,
+        **kwargs: Any,
+    ) -> MomentoChatMessageHistory:
+        """Construct cache from CacheClient parameters."""
+        try:
+            from momento import CacheClient, Configurations, CredentialProvider
+        except ImportError:
+            raise ImportError(
+                "Could not import momento python package. "
+                "Please install it with `pip install momento`."
+            )
+        if configuration is None:
+            configuration = Configurations.Laptop.v1()
+        auth_token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
+        credentials = CredentialProvider.from_string(auth_token)
+        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
+        return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
+
+    @property
+    def messages(self) -> list[BaseMessage]:  # type: ignore[override]
+        """Retrieve the messages from Momento.
+ + Raises: + SdkException: Momento service or network error + Exception: Unexpected response + + Returns: + list[BaseMessage]: List of cached messages + """ + from momento.responses import CacheListFetch + + fetch_response = self.cache_client.list_fetch(self.cache_name, self.key) + + if isinstance(fetch_response, CacheListFetch.Hit): + items = [json.loads(m) for m in fetch_response.value_list_string] + return messages_from_dict(items) + elif isinstance(fetch_response, CacheListFetch.Miss): + return [] + elif isinstance(fetch_response, CacheListFetch.Error): + raise fetch_response.inner_exception + else: + raise Exception(f"Unexpected response: {fetch_response}") + + def add_user_message(self, message: str) -> None: + """Store a user message in the cache. + + Args: + message (str): The message to store. + """ + self.__add_message(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + """Store an AI message in the cache. + + Args: + message (str): The message to store. + """ + self.__add_message(AIMessage(content=message)) + + def __add_message(self, message: BaseMessage) -> None: + """Store a message in the cache. + + Args: + message (BaseMessage): The message object to store. + + Raises: + SdkException: Momento service or network error. + Exception: Unexpected response. + """ + from momento.responses import CacheListPushBack + + item = json.dumps(_message_to_dict(message)) + push_response = self.cache_client.list_push_back( + self.cache_name, self.key, item, ttl=self.ttl + ) + if isinstance(push_response, CacheListPushBack.Success): + return None + elif isinstance(push_response, CacheListPushBack.Error): + raise push_response.inner_exception + else: + raise Exception(f"Unexpected response: {push_response}") + + def clear(self) -> None: + """Remove the session's messages from the cache. + + Raises: + SdkException: Momento service or network error. + Exception: Unexpected response. 
+ """ + from momento.responses import CacheDelete + + delete_response = self.cache_client.delete(self.cache_name, self.key) + if isinstance(delete_response, CacheDelete.Success): + return None + elif isinstance(delete_response, CacheDelete.Error): + raise delete_response.inner_exception + else: + raise Exception(f"Unexpected response: {delete_response}") diff --git a/poetry.lock b/poetry.lock index a316ee03..df7d3dc1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2621,7 +2621,7 @@ name = "grpcio" version = "1.47.5" description = "HTTP/2-based RPC framework" category = "main" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "grpcio-1.47.5-cp310-cp310-linux_armv7l.whl", hash = "sha256:acc73289d0c44650aa1f21eccfa967f5623b01c3b5e2b4596fe5f9c5bf10956d"}, @@ -4462,6 +4462,39 @@ files = [ {file = "mmh3-3.1.0.tar.gz", hash = "sha256:9b0f2b2ab4a915333c9d1089572e290a021ebb5b900bb7f7114dccc03995d732"}, ] +[[package]] +name = "momento" +version = "1.5.0" +description = "SDK for Momento" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "momento-1.5.0-py3-none-any.whl", hash = "sha256:7f633fb26ddf1bfcaf99a7add37b085f0e23c96f85972b655eaaf2de1c61d7f1"}, + {file = "momento-1.5.0.tar.gz", hash = "sha256:68ca5d24b4cb08c5c0bd22d4edd3b8b0fcf087a85d30673cb2c55b11971c76ec"}, +] + +[package.dependencies] +grpcio = ">=1.46.0,<2.0.0" +momento-wire-types = ">=0.64,<0.65" +pyjwt = ">=2.4.0,<3.0.0" + +[[package]] +name = "momento-wire-types" +version = "0.64.0" +description = "Momento Client Proto Generated Files" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "momento_wire_types-0.64.0-py3-none-any.whl", hash = "sha256:30a5d523cef9209c0863db25e6344044b0c7240fea183a41c1433ca83cefea5b"}, + {file = "momento_wire_types-0.64.0.tar.gz", hash = "sha256:5d3647210d49d0c3032a74ae5f5cc012a9faf826786272e1436a3d84d70a8bd5"}, +] + +[package.dependencies] +grpcio = "*" +protobuf = ">=3,<5" + [[package]] name = "monotonic" version = "1.6" @@ -6410,7 +6443,7 @@ name = "protobuf" version = "3.19.6" description = "Protocol Buffers" category = "main" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"}, @@ -10865,7 +10898,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] +all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-ai-formrecognizer", 
"azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "openai"] cohere = ["cohere"] docarray = ["docarray"] @@ -10879,4 +10912,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "0f2ed0b37063f533b3403d545b48633ced4d02c15bc8f4e47f1ded1652ab9764" +content-hash = "5e83a1f4ca8c0d3107363e393485174fd72ce9db93db5dc7c21b2dd37b184e66" diff --git a/pyproject.toml b/pyproject.toml index 5f595978..75d7bf5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ scikit-learn = {version = "^1.2.2", optional = true} azure-ai-formrecognizer = {version = "^3.2.1", optional = true} azure-ai-vision = {version = "^0.11.1b1", optional = true} azure-cognitiveservices-speech = {version = "^1.28.0", optional = true} +momento = {version = "^1.5.0", optional = true} bibtexparser = {version = "^1.4.0", optional = true} [tool.poetry.group.docs.dependencies] @@ -160,6 +161,7 @@ pymongo = "^4.3.3" cassandra-driver = "^3.27.0" arxiv = "^1.4" mastodon-py = "^1.8.1" +momento = "^1.5.0" [tool.poetry.group.lint.dependencies] ruff = "^0.0.249" @@ -253,6 +255,7 @@ all = [ "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", + "momento" ] # An extra used to be able to add extended testing. diff --git a/tests/integration_tests/cache/test_momento_cache.py b/tests/integration_tests/cache/test_momento_cache.py new file mode 100644 index 00000000..dec8f44c --- /dev/null +++ b/tests/integration_tests/cache/test_momento_cache.py @@ -0,0 +1,94 @@ +"""Test Momento cache functionality. + +To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid +Momento auth token. This can be obtained by signing up for a free +Momento account at https://gomomento.com/. 
+""" +from __future__ import annotations + +import uuid +from datetime import timedelta +from typing import Iterator + +import pytest +from momento import CacheClient, Configurations, CredentialProvider + +import langchain +from langchain.cache import MomentoCache +from langchain.schema import Generation, LLMResult +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def random_string() -> str: + return str(uuid.uuid4()) + + +@pytest.fixture(scope="module") +def momento_cache() -> Iterator[MomentoCache]: + cache_name = f"langchain-test-cache-{random_string()}" + client = CacheClient( + Configurations.Laptop.v1(), + CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"), + default_ttl=timedelta(seconds=30), + ) + try: + llm_cache = MomentoCache(client, cache_name) + langchain.llm_cache = llm_cache + yield llm_cache + finally: + client.delete_cache(cache_name) + + +def test_invalid_ttl() -> None: + client = CacheClient( + Configurations.Laptop.v1(), + CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"), + default_ttl=timedelta(seconds=30), + ) + with pytest.raises(ValueError): + MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1)) + + +def test_momento_cache_miss(momento_cache: MomentoCache) -> None: + llm = FakeLLM() + stub_llm_output = LLMResult(generations=[[Generation(text="foo")]]) + assert llm.generate([random_string()]) == stub_llm_output + + +@pytest.mark.parametrize( + "prompts, generations", + [ + # Single prompt, single generation + ([random_string()], [[random_string()]]), + # Single prompt, multiple generations + ([random_string()], [[random_string(), random_string()]]), + # Single prompt, multiple generations + ([random_string()], [[random_string(), random_string(), random_string()]]), + # Multiple prompts, multiple generations + ( + [random_string(), random_string()], + [[random_string()], [random_string(), random_string()]], + ), + ], +) +def test_momento_cache_hit( + momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]] +) -> None: + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + + llm_generations = [ + [ + Generation(text=generation, generation_info=params) + for generation in prompt_i_generations + ] + for prompt_i_generations in generations + ] + for prompt_i, llm_generations_i in zip(prompts, llm_generations): + momento_cache.update(prompt_i, llm_string, llm_generations_i) + + assert llm.generate(prompts) == LLMResult( + generations=llm_generations, llm_output={} + ) diff --git a/tests/integration_tests/memory/test_momento.py b/tests/integration_tests/memory/test_momento.py new file mode 100644 index 00000000..0260f6db --- /dev/null +++ b/tests/integration_tests/memory/test_momento.py @@ -0,0 +1,70 @@ +"""Test Momento chat message history functionality. + +To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid +Momento auth token. This can be obtained by signing up for a free +Momento account at https://gomomento.com/. 
+""" +import json +import uuid +from datetime import timedelta +from typing import Iterator + +import pytest +from momento import CacheClient, Configurations, CredentialProvider + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import MomentoChatMessageHistory +from langchain.schema import _message_to_dict + + +def random_string() -> str: + return str(uuid.uuid4()) + + +@pytest.fixture(scope="function") +def message_history() -> Iterator[MomentoChatMessageHistory]: + cache_name = f"langchain-test-cache-{random_string()}" + client = CacheClient( + Configurations.Laptop.v1(), + CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"), + default_ttl=timedelta(seconds=30), + ) + try: + chat_message_history = MomentoChatMessageHistory( + session_id="my-test-session", + cache_client=client, + cache_name=cache_name, + ) + yield chat_message_history + finally: + client.delete_cache(cache_name) + + +def test_memory_empty_on_new_session( + message_history: MomentoChatMessageHistory, +) -> None: + memory = ConversationBufferMemory( + memory_key="foo", chat_memory=message_history, return_messages=True + ) + assert memory.chat_memory.messages == [] + + +def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None: + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # Add some messages to the memory store + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # Verify that the messages are in the store + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Verify clearing the store + memory.chat_memory.clear() + assert memory.chat_memory.messages == []