mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
190 lines
6.9 KiB
Python
190 lines
6.9 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import json
|
||
|
from datetime import timedelta
|
||
|
from typing import TYPE_CHECKING, Any, Optional
|
||
|
|
||
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||
|
from langchain_core.messages import (
|
||
|
BaseMessage,
|
||
|
message_to_dict,
|
||
|
messages_from_dict,
|
||
|
)
|
||
|
from langchain_core.utils import get_from_env
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
import momento
|
||
|
|
||
|
|
||
|
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
    """Make sure the named Momento cache is present, creating it if needed.

    Args:
        cache_client: Momento client used to issue the create-cache request.
        cache_name: Name of the cache that must exist.

    Raises:
        SdkException: Momento service or network error
        Exception: Unexpected response
    """
    from momento.responses import CreateCache

    response = cache_client.create_cache(cache_name)

    # A cache that already exists is just as usable as one created now.
    if isinstance(response, (CreateCache.Success, CreateCache.CacheAlreadyExists)):
        return None

    if isinstance(response, CreateCache.Error):
        # Surface the SDK's own exception rather than wrapping it.
        raise response.inner_exception

    raise Exception(f"Unexpected response cache creation: {response}")
|
||
|
|
||
|
|
||
|
class MomentoChatMessageHistory(BaseChatMessageHistory):
    """Chat message history cache that uses Momento as a backend.

    Each session's messages are kept in a single Momento list stored under the
    key ``key_prefix + session_id``. Every message is serialized with
    ``message_to_dict`` and ``json.dumps`` before being appended.

    See https://gomomento.com/"""

    def __init__(
        self,
        session_id: str,
        cache_client: momento.CacheClient,
        cache_name: str,
        *,
        key_prefix: str = "message_store:",
        ttl: Optional[timedelta] = None,
        ensure_cache_exists: bool = True,
    ) -> None:
        """Instantiate a chat message history cache that uses Momento as a backend.

        Note: to instantiate the cache client passed to MomentoChatMessageHistory,
        you must have a Momento account at https://gomomento.com/.

        Args:
            session_id (str): The session ID to use for this chat session.
            cache_client (CacheClient): The Momento cache client.
            cache_name (str): The name of the cache to use to store the messages.
            key_prefix (str, optional): The prefix to apply to the cache key.
                Defaults to "message_store:".
            ttl (Optional[timedelta], optional): The TTL to use for the messages.
                Defaults to None, ie the default TTL of the cache will be used.
            ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
                Defaults to True.

        Raises:
            ImportError: Momento python package is not installed.
            TypeError: cache_client is not of type momento.CacheClient.
            SdkException: Momento service or network error while creating the
                cache (only when ensure_cache_exists is True).
            Exception: Unexpected response while creating the cache (only when
                ensure_cache_exists is True).
        """
        try:
            from momento import CacheClient
            from momento.requests import CollectionTtl
        except ImportError as e:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            ) from e
        if not isinstance(cache_client, CacheClient):
            raise TypeError("cache_client must be a momento.CacheClient object.")
        if ensure_cache_exists:
            _ensure_cache_exists(cache_client, cache_name)
        # All messages of this session live in one list under this key.
        self.key = key_prefix + session_id
        self.cache_client = cache_client
        self.cache_name = cache_name
        # Without an explicit TTL, fall back to the cache's default TTL.
        if ttl is not None:
            self.ttl = CollectionTtl.of(ttl)
        else:
            self.ttl = CollectionTtl.from_cache_ttl()

    @classmethod
    def from_client_params(
        cls,
        session_id: str,
        cache_name: str,
        ttl: timedelta,
        *,
        configuration: Optional[momento.config.Configuration] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,  # for backwards compatibility
        **kwargs: Any,
    ) -> MomentoChatMessageHistory:
        """Construct cache from CacheClient parameters.

        Credentials are resolved in this order:

        1. the ``auth_token`` argument (legacy name);
        2. the ``api_key`` argument;
        3. the ``MOMENTO_AUTH_TOKEN`` environment variable (legacy name);
        4. the ``MOMENTO_API_KEY`` environment variable.

        Args:
            session_id: The session ID to use for this chat session.
            cache_name: The name of the cache to use to store the messages.
            ttl: Default TTL for the created CacheClient. It is also forwarded
                as the message TTL.
            configuration: Momento client configuration. Defaults to
                ``Configurations.Laptop.v1()``.
            api_key: Momento API key.
            auth_token: Deprecated alias for ``api_key``.
            **kwargs: Extra keyword arguments forwarded to ``__init__``.

        Raises:
            ImportError: Momento python package is not installed.
            ValueError: No credential was supplied and neither environment
                variable is set.
        """
        try:
            from momento import CacheClient, Configurations, CredentialProvider
        except ImportError as e:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            ) from e
        if configuration is None:
            configuration = Configurations.Laptop.v1()

        # Explicit arguments must take precedence over the environment.
        # Previously a set MOMENTO_AUTH_TOKEN env var silently overrode an
        # explicitly passed `api_key`.
        resolved_key = auth_token or api_key
        if not resolved_key:
            # Check `MOMENTO_AUTH_TOKEN` first for backwards compatibility.
            try:
                resolved_key = get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
            except ValueError:
                resolved_key = get_from_env("api_key", "MOMENTO_API_KEY")
        credentials = CredentialProvider.from_string(resolved_key)
        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
        return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)

    @property
    def messages(self) -> list[BaseMessage]:  # type: ignore[override]
        """Retrieve the messages from Momento.

        Raises:
            SdkException: Momento service or network error
            Exception: Unexpected response

        Returns:
            list[BaseMessage]: List of cached messages, or an empty list when
            the session's key is not present in the cache.
        """
        from momento.responses import CacheListFetch

        fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)

        if isinstance(fetch_response, CacheListFetch.Hit):
            # Each list element is a JSON-encoded message dict (see add_message).
            items = [json.loads(m) for m in fetch_response.value_list_string]
            return messages_from_dict(items)
        elif isinstance(fetch_response, CacheListFetch.Miss):
            return []
        elif isinstance(fetch_response, CacheListFetch.Error):
            raise fetch_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {fetch_response}")

    def add_message(self, message: BaseMessage) -> None:
        """Store a message in the cache.

        The message is serialized to JSON and appended to the end of the
        session's list, using the TTL configured at construction time.

        Args:
            message (BaseMessage): The message object to store.

        Raises:
            SdkException: Momento service or network error.
            Exception: Unexpected response.
        """
        from momento.responses import CacheListPushBack

        item = json.dumps(message_to_dict(message))
        push_response = self.cache_client.list_push_back(
            self.cache_name, self.key, item, ttl=self.ttl
        )
        if isinstance(push_response, CacheListPushBack.Success):
            return None
        elif isinstance(push_response, CacheListPushBack.Error):
            raise push_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {push_response}")

    def clear(self) -> None:
        """Remove the session's messages from the cache.

        Deletes the whole list stored under this session's key.

        Raises:
            SdkException: Momento service or network error.
            Exception: Unexpected response.
        """
        from momento.responses import CacheDelete

        delete_response = self.cache_client.delete(self.cache_name, self.key)
        if isinstance(delete_response, CacheDelete.Success):
            return None
        elif isinstance(delete_response, CacheDelete.Error):
            raise delete_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {delete_response}")
|