forked from Archives/langchain
7047a2c1af
# Add Momento as a standard cache and chat message history provider This PR adds Momento as a standard caching provider. Implements the interface, adds integration tests, and documentation. We also add Momento as a chat history message provider along with integration tests, and documentation. [Momento](https://www.gomomento.com/) is a fully serverless cache. Similar to S3 or DynamoDB, it requires zero configuration, infrastructure management, and is instantly available. Users sign up for free and get 50GB of data in/out for free every month. ## Before submitting ✅ We have added documentation, notebooks, and integration tests demonstrating usage. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
621 lines
21 KiB
Python
621 lines
21 KiB
Python
"""Beta Feature: base interface for cache."""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import inspect
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from datetime import timedelta
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from sqlalchemy import Column, Integer, String, create_engine, select
|
|
from sqlalchemy.engine.base import Engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
from langchain.utils import get_from_env
|
|
|
|
try:
|
|
from sqlalchemy.orm import declarative_base
|
|
except ImportError:
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.schema import Generation
|
|
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
|
|
|
if TYPE_CHECKING:
|
|
import momento
|
|
|
|
# Type of a cached value: the list of generations that the cache classes in
# this module return from `lookup` and accept in `update`.
RETURN_VAL_TYPE = List[Generation]
|
|
|
|
|
|
def _hash(_input: str) -> str:
|
|
"""Use a deterministic hashing approach."""
|
|
return hashlib.md5(_input.encode()).hexdigest()
|
|
|
|
|
|
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
|
|
"""Dump generations to json.
|
|
|
|
Args:
|
|
generations (RETURN_VAL_TYPE): A list of language model generations.
|
|
|
|
Returns:
|
|
str: Json representing a list of generations.
|
|
"""
|
|
return json.dumps([generation.dict() for generation in generations])
|
|
|
|
|
|
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
|
"""Load generations from json.
|
|
|
|
Args:
|
|
generations_json (str): A string of json representing a list of generations.
|
|
|
|
Raises:
|
|
ValueError: Could not decode json string to list of generations.
|
|
|
|
Returns:
|
|
RETURN_VAL_TYPE: A list of generations.
|
|
"""
|
|
try:
|
|
results = json.loads(generations_json)
|
|
return [Generation(**generation_dict) for generation_dict in results]
|
|
except json.JSONDecodeError:
|
|
raise ValueError(
|
|
f"Could not decode json to list of generations: {generations_json}"
|
|
)
|
|
|
|
|
|
class BaseCache(ABC):
    """Abstract interface that every cache backend in this module implements."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations cached for ``prompt`` and ``llm_string``, if any."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` under the given ``prompt`` and ``llm_string``."""

    @abstractmethod
    def clear(self, **kwargs: Any) -> None:
        """Empty the cache; backends may accept extra keyword arguments."""
|
|
|
|
|
|
class InMemoryCache(BaseCache):
    """Cache that keeps every entry in a plain in-process dictionary."""

    def __init__(self) -> None:
        """Start with no cached entries."""
        # Keyed by the (prompt, llm_string) pair.
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations stored for this prompt/LLM pair, or None."""
        key = (prompt, llm_string)
        return self._cache.get(key)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` for this prompt/LLM pair, replacing any old value."""
        key = (prompt, llm_string)
        self._cache[key] = return_val

    def clear(self, **kwargs: Any) -> None:
        """Drop all entries by rebinding the store to a new empty dict."""
        self._cache = dict()
|
|
|
|
|
|
# Declarative base class that the FullLLMCache table model inherits from.
Base = declarative_base()
|
|
|
|
|
|
class FullLLMCache(Base):  # type: ignore
    """ORM model for the ``full_llm_cache`` table (one row per generation)."""

    __tablename__ = "full_llm_cache"
    # prompt, llm and idx together form the composite primary key.
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    idx = Column(Integer, primary_key=True)
    # Payload column holding the cached response text.
    response = Column(String)
|
|
|
|
|
|
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
        """Initialize by creating all tables.

        Args:
            engine (Engine): SQLAlchemy engine used for every cache operation.
            cache_schema (Type[FullLLMCache]): Declarative model describing the
                cache table. Defaults to ``FullLLMCache``.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            Optional[RETURN_VAL_TYPE]: Generations rebuilt from the stored
            response text, ordered by ``idx``, or None when no row matches.
        """
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt == prompt)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            rows = session.execute(stmt).fetchall()
            if rows:
                return [Generation(text=row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string.

        Only the ``text`` of each generation is stored, one row per generation
        with ``idx`` set to its position in ``return_val``.
        """
        items = [
            self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
            for i, gen in enumerate(return_val)
        ]
        # session.begin() commits on success and rolls back on error.
        with Session(self.engine) as session, session.begin():
            for item in items:
                session.merge(item)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache by deleting every row of the cache table."""
        with Session(self.engine) as session:
            # The declarative model class has no `.delete()` construct, so
            # issue the bulk delete through a query and commit it explicitly;
            # otherwise the session would roll back on close.
            session.query(self.cache_schema).delete()
            session.commit()
|
|
|
|
|
|
class SQLiteCache(SQLAlchemyCache):
    """SQLAlchemy-backed cache stored in a local SQLite database file."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Create an engine for ``database_path`` and initialize the parent cache.

        Args:
            database_path (str): Path of the SQLite database file.
        """
        sqlite_url = f"sqlite:///{database_path}"
        super().__init__(create_engine(sqlite_url))
|
|
|
|
|
|
class RedisCache(BaseCache):
    """Cache that uses Redis as a backend."""

    # TODO - implement a TTL policy in Redis

    def __init__(self, redis_: Any):
        """Initialize by passing in Redis instance.

        Args:
            redis_ (Any): A ``redis.Redis`` client instance.

        Raises:
            ValueError: The redis package is not installed, or ``redis_`` is
                not a ``redis.Redis`` instance.
        """
        try:
            from redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Redis object.")
        self.redis = redis_

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            Optional[RETURN_VAL_TYPE]: One generation per field of the Redis
            HASH stored under the computed key, or None when it is empty.
        """
        generations = []
        # Read from a Redis HASH
        results = self.redis.hgetall(self._key(prompt, llm_string))
        if results:
            for _, text in results.items():
                generations.append(Generation(text=text))
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Each generation's text is written to a HASH field named after its
        position in ``return_val``.
        """
        # Write to a Redis HASH
        key = self._key(prompt, llm_string)
        self.redis.hset(
            key,
            mapping={
                str(idx): generation.text for idx, generation in enumerate(return_val)
            },
        )

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously.

        Any remaining keyword arguments are forwarded to ``flushdb``.
        """
        # Pop rather than get: leaving the key in kwargs would pass
        # `asynchronous` twice and raise a TypeError.
        asynchronous = kwargs.pop("asynchronous", False)
        self.redis.flushdb(asynchronous=asynchronous, **kwargs)
|
|
|
|
|
|
class RedisSemanticCache(BaseCache):
    """Semantic cache that stores prompts in Redis vector-store indexes."""

    # TODO - implement a TTL policy in Redis

    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """Configure the semantic cache; no Redis call is made here.

        Args:
            redis_url (str): URL to connect to Redis.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            score_threshold (float, 0.2): Threshold passed to
                ``similarity_search_limit_score`` on every lookup.

        Example:
        .. code-block:: python
            import langchain

            from langchain.cache import RedisSemanticCache
            from langchain.embeddings import OpenAIEmbeddings

            langchain.llm_cache = RedisSemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
            )

        """
        # One vectorstore client per index name, created lazily.
        self._cache_dict: Dict[str, RedisVectorstore] = {}
        self.redis_url = redis_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        """Derive the index name for ``llm_string`` from its hash."""
        return f"cache:{_hash(llm_string)}"

    def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
        """Return the vectorstore client for ``llm_string``, creating it if needed."""
        index_name = self._index_name(llm_string)

        # Create the client only when this index has not been seen yet.
        if index_name not in self._cache_dict:
            try:
                # Prefer attaching to an index that already exists.
                vectorstore = RedisVectorstore.from_existing_index(
                    embedding=self.embedding,
                    index_name=index_name,
                    redis_url=self.redis_url,
                )
            except ValueError:
                # Otherwise build a new index sized by a probe embedding.
                vectorstore = RedisVectorstore(
                    embedding_function=self.embedding.embed_query,
                    index_name=index_name,
                    redis_url=self.redis_url,
                )
                probe = self.embedding.embed_query(text="test")
                vectorstore._create_index(dim=len(probe))
            self._cache_dict[index_name] = vectorstore

        return self._cache_dict[index_name]

    def clear(self, **kwargs: Any) -> None:
        """Drop the index for ``kwargs["llm_string"]`` if this instance holds it."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name not in self._cache_dict:
            return
        self._cache_dict[index_name].drop_index(
            index_name=index_name, delete_documents=True, redis_url=self.redis_url
        )
        del self._cache_dict[index_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return generations cached for the closest stored prompt, or None."""
        llm_cache = self._get_llm_cache(llm_string)
        # Ask for the single closest prompt within the score threshold.
        hits = llm_cache.similarity_search_limit_score(
            query=prompt,
            k=1,
            score_threshold=self.score_threshold,
        )
        generations = [
            Generation(text=text)
            for document in (hits or [])
            for text in document.metadata["return_val"]
        ]
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Index ``prompt`` with the generation texts stored in its metadata."""
        llm_cache = self._get_llm_cache(llm_string)
        generation_texts = [generation.text for generation in return_val]
        # Write to vectorstore
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": generation_texts,
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
|
|
|
|
|
|
class GPTCache(BaseCache):
    """Cache that uses GPTCache as a backend."""

    def __init__(
        self,
        init_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = None,
    ):
        """Initialize by passing in init function (default: `None`).

        Args:
            init_func (Optional[Callable]): Function that initializes a new
                `gptcache.Cache` object. It may take either the cache object
                alone or the cache object and the llm string.
                (default: `None`)

        Raises:
            ImportError: The gptcache python package is not installed.

        Example:
        .. code-block:: python

            # Initialize GPTCache with a custom init function
            import gptcache
            import langchain
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import manager_factory

            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other

            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                cache_obj.init(
                    pre_embedding_func=get_prompt,
                    data_manager=manager_factory(
                        manager="map",
                        data_dir=f"map_cache_{llm}"
                    ),
                )

            langchain.llm_cache = GPTCache(init_gptcache)

        """
        try:
            import gptcache  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gptcache python package. "
                "Please install it with `pip install gptcache`."
            )

        self.init_gptcache_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = init_func
        # One gptcache.Cache object per llm_string.
        self.gptcache_dict: Dict[str, Any] = {}

    def _new_gptcache(self, llm_string: str) -> Any:
        """Create, initialize and register a new gptcache object for llm_string."""
        from gptcache import Cache
        from gptcache.manager.factory import get_data_manager
        from gptcache.processor.pre import get_prompt

        _gptcache = Cache()
        if self.init_gptcache_func is not None:
            # Support both init signatures: (cache) and (cache, llm_string).
            sig = inspect.signature(self.init_gptcache_func)
            if len(sig.parameters) == 2:
                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
            else:
                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
        else:
            _gptcache.init(
                pre_embedding_func=get_prompt,
                data_manager=get_data_manager(data_path=llm_string),
            )

        self.gptcache_dict[llm_string] = _gptcache
        return _gptcache

    def _get_gptcache(self, llm_string: str) -> Any:
        """Get a cache object.

        When the corresponding llm model cache does not exist, it will be created."""
        # Do not pass `self._new_gptcache(...)` as the `dict.get` default: the
        # default is evaluated eagerly, which would build and register a new
        # cache object on every call and replace the existing one.
        _gptcache = self.gptcache_dict.get(llm_string)
        if _gptcache is None:
            _gptcache = self._new_gptcache(llm_string)
        return _gptcache

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cache data.
        First, retrieve the corresponding cache object using the `llm_string` parameter,
        and then retrieve the data from the cache based on the `prompt`.

        Returns None when no cache object exists yet for `llm_string` or when
        the cache holds nothing for `prompt`.
        """
        from gptcache.adapter.api import get

        _gptcache = self.gptcache_dict.get(llm_string, None)
        if _gptcache is None:
            return None
        res = get(prompt, cache_obj=_gptcache)
        if res:
            return [
                Generation(**generation_dict) for generation_dict in json.loads(res)
            ]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache.
        First, retrieve the corresponding cache object using the `llm_string` parameter,
        and then store the `prompt` and `return_val` in the cache object.
        """
        from gptcache.adapter.api import put

        _gptcache = self._get_gptcache(llm_string)
        # Generations are stored as a JSON list of their dict() form.
        handled_data = json.dumps([generation.dict() for generation in return_val])
        put(prompt, handled_data, cache_obj=_gptcache)
        return None

    def clear(self, **kwargs: Any) -> None:
        """Clear cache: flush every known cache object, then forget them all."""
        from gptcache import Cache

        for gptcache_instance in self.gptcache_dict.values():
            gptcache_instance = cast(Cache, gptcache_instance)
            gptcache_instance.flush()

        self.gptcache_dict.clear()
|
|
|
|
|
|
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
    """Ask Momento to create ``cache_name``, accepting an already-existing cache.

    Raises:
        SdkException: Momento service or network error
        Exception: Unexpected response
    """
    from momento.responses import CreateCache

    response = cache_client.create_cache(cache_name)
    # Both outcomes mean the cache is available.
    acceptable = (CreateCache.Success, CreateCache.CacheAlreadyExists)
    if isinstance(response, acceptable):
        return None
    if isinstance(response, CreateCache.Error):
        raise response.inner_exception
    raise Exception(f"Unexpected response cache creation: {response}")
|
|
|
|
|
|
def _validate_ttl(ttl: Optional[timedelta]) -> None:
|
|
if ttl is not None and ttl <= timedelta(seconds=0):
|
|
raise ValueError(f"ttl must be positive but was {ttl}.")
|
|
|
|
|
|
class MomentoCache(BaseCache):
    """Cache that uses Momento as a backend. See https://gomomento.com/"""

    def __init__(
        self,
        cache_client: momento.CacheClient,
        cache_name: str,
        *,
        ttl: Optional[timedelta] = None,
        ensure_cache_exists: bool = True,
    ):
        """Instantiate a prompt cache using Momento as a backend.

        Note: to instantiate the cache client passed to MomentoCache,
        you must have a Momento account. See https://gomomento.com/.

        Args:
            cache_client (CacheClient): The Momento cache client.
            cache_name (str): The name of the cache to use to store the data.
            ttl (Optional[timedelta], optional): The time to live for the cache items.
                Defaults to None, ie use the client default TTL.
            ensure_cache_exists (bool, optional): Create the cache if it doesn't
                exist. Defaults to True.

        Raises:
            ImportError: Momento python package is not installed.
            TypeError: cache_client is not a momento.CacheClient object.
            ValueError: ttl is not None and is not positive.
        """
        try:
            from momento import CacheClient
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if not isinstance(cache_client, CacheClient):
            raise TypeError("cache_client must be a momento.CacheClient object.")
        _validate_ttl(ttl)
        if ensure_cache_exists:
            _ensure_cache_exists(cache_client, cache_name)

        self.cache_client, self.cache_name, self.ttl = cache_client, cache_name, ttl

    @classmethod
    def from_client_params(
        cls,
        cache_name: str,
        ttl: timedelta,
        *,
        configuration: Optional[momento.config.Configuration] = None,
        auth_token: Optional[str] = None,
        **kwargs: Any,
    ) -> MomentoCache:
        """Build a CacheClient from the given parameters and wrap it in a cache.

        When ``configuration`` is None the ``Configurations.Laptop.v1()``
        preset is used; when ``auth_token`` is falsy it is read through
        ``get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")``.
        """
        try:
            from momento import CacheClient, Configurations, CredentialProvider
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if configuration is None:
            configuration = Configurations.Laptop.v1()
        token = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
        client = CacheClient(
            configuration,
            CredentialProvider.from_string(token),
            default_ttl=ttl,
        )
        return cls(client, cache_name, ttl=ttl, **kwargs)

    def __key(self, prompt: str, llm_string: str) -> str:
        """Compute cache key from prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Returns:
            str: The cache key.
        """
        combined = prompt + llm_string
        return _hash(combined)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Lookup llm generations in cache by prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Raises:
            SdkException: Momento service or network error

        Returns:
            Optional[RETURN_VAL_TYPE]: A list of language model generations,
            or None on a miss or when the stored list is empty.
        """
        from momento.responses import CacheGet

        response = self.cache_client.get(
            self.cache_name, self.__key(prompt, llm_string)
        )
        if isinstance(response, CacheGet.Hit):
            # An empty stored list is reported the same way as a miss.
            return _load_generations_from_json(response.value_string) or None
        is_miss = isinstance(response, CacheGet.Miss)
        if not is_miss and isinstance(response, CacheGet.Error):
            raise response.inner_exception
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store llm generations in cache.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model string.
            return_val (RETURN_VAL_TYPE): A list of language model generations.

        Raises:
            SdkException: Momento service or network error
            Exception: Unexpected response
        """
        set_response = self.cache_client.set(
            self.cache_name,
            self.__key(prompt, llm_string),
            _dump_generations_to_json(return_val),
            self.ttl,
        )
        from momento.responses import CacheSet

        if isinstance(set_response, CacheSet.Success):
            return
        if isinstance(set_response, CacheSet.Error):
            raise set_response.inner_exception
        raise Exception(f"Unexpected response: {set_response}")

    def clear(self, **kwargs: Any) -> None:
        """Flush every item from the Momento cache named ``self.cache_name``.

        Raises:
            SdkException: Momento service or network error
        """
        from momento.responses import CacheFlush

        response = self.cache_client.flush_cache(self.cache_name)
        succeeded = isinstance(response, CacheFlush.Success)
        if not succeeded and isinstance(response, CacheFlush.Error):
            raise response.inner_exception
|