Cassandra support for LLM cache (exact-match and semantic) (#9772)

This PR implements two new classes in the cache module: `CassandraCache`
and `CassandraSemanticCache`, similar in structure and functionality to
their Redis counterparts: they provide a cache for the response to a
(prompt, llm) pair.
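
As a quick illustration, wiring the exact-match cache into LangChain might look like this (a minimal sketch; the cluster connection and keyspace name are illustrative, and the keyspace is assumed to already exist):

```python
import langchain
from cassandra.cluster import Cluster

from langchain.cache import CassandraCache

# Connect to a running Cassandra / Astra DB instance
# (connection details are deployment-specific; a local cluster is assumed here).
session = Cluster().connect()

# Route all LLM calls through the exact-match cache: from now on,
# repeated calls with the same (prompt, llm) pair are served from Cassandra.
langchain.llm_cache = CassandraCache(session=session, keyspace="demo_keyspace")
```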

Integration tests are included. Moreover, linting and type checks are
all passing on my machine.

Dependencies: `pyproject.toml` and `poetry.lock` are updated to the newest
version of cassIO (the very same as in the Cassandra vector-store metadata
PR, submitted as #9280).

If I may suggest, this PR and #9280 might be reviewed together (as they
bring along the same poetry changes), so I'm tagging @baskaryan, who
already helped out a little with poetry-related conflicts there. (Thank
you!)

I'd be happy to add a short notebook if this is deemed necessary (though
it seems to me that, unlike e.g. vector stores, caches are not covered in
dedicated notebooks).

Thank you!

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Stefano Lottini, 2023-09-04 05:27:02 +02:00 (committed by GitHub)
parent 16945c9922
commit c9ff0ab2e9
4 changed files with 849 additions and 31 deletions


@@ -27,6 +27,7 @@ import json
import logging
import warnings
from datetime import timedelta
+from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Any,
@@ -51,6 +52,7 @@ except ImportError:
from langchain.embeddings.base import Embeddings
+from langchain.llms.base import LLM, get_prompts
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import ChatGeneration, Generation
@@ -62,6 +64,7 @@ logger = logging.getLogger(__file__)
if TYPE_CHECKING:
    import momento
+    from cassandra.cluster import Session as CassandraSession

def _hash(_input: str) -> str:
@@ -711,3 +714,264 @@ class MomentoCache(BaseCache):
            pass
        elif isinstance(flush_response, CacheFlush.Error):
            raise flush_response.inner_exception

CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache"
CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None

class CassandraCache(BaseCache):
    """
    Cache that uses Cassandra / Astra DB as a backend.

    It uses a single Cassandra table.
    The lookup keys (which together form the primary key) are:
        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters
          (needed to prevent same-prompt-different-model collisions)
    """
def __init__(
self,
session: CassandraSession,
keyspace: str,
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
):
"""
Initialize with a ready session and a keyspace name.
Args:
session (cassandra.cluster.Session): an open Cassandra session
keyspace (str): the keyspace to use for storing the cache
table_name (str): name of the Cassandra table to use as cache
ttl_seconds (optional int): time-to-live for cache entries
(default: None, i.e. forever)
"""
try:
from cassio.table import ElasticCassandraTable
except (ImportError, ModuleNotFoundError):
raise ValueError(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
self.session = session
self.keyspace = keyspace
self.table_name = table_name
self.ttl_seconds = ttl_seconds
self.kv_cache = ElasticCassandraTable(
session=self.session,
keyspace=self.keyspace,
table=self.table_name,
keys=["llm_string", "prompt"],
primary_key_type=["TEXT", "TEXT"],
ttl_seconds=self.ttl_seconds,
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
item = self.kv_cache.get(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
)
if item:
return _load_generations_from_json(item["body_blob"])
else:
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
blob = _dump_generations_to_json(return_val)
self.kv_cache.put(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
body_blob=blob,
)
def delete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
"""
A wrapper around `delete` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here.
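
        Example (names illustrative):
            cache.delete_through_llm("my cached prompt", llm, stop=["Observation:"])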
"""
llm_string = get_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)[1]
return self.delete(prompt, llm_string=llm_string)
def delete(self, prompt: str, llm_string: str) -> None:
"""Evict from cache if there's an entry."""
return self.kv_cache.delete(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
)
def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
self.kv_cache.clear()
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None
CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16

class CassandraSemanticCache(BaseCache):
    """
    Cache that uses Cassandra as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    It uses a single (vector) Cassandra table and stores, in principle,
    cached values from several LLMs, so the LLM's llm_string is part
    of the rows' primary keys.

    The similarity is based on one of several distance metrics (default: "dot").
    If you choose another metric, the default threshold should be re-tuned accordingly.
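
    An example of overriding the metric (the "cos" metric name is an assumption
    about what cassIO accepts, and the threshold value is illustrative):
        semantic_cache = CassandraSemanticCache(
            session,
            keyspace,
            embedding=my_embeddings,
            distance_metric="cos",
            score_threshold=0.9,
        )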
"""
def __init__(
self,
session: CassandraSession,
keyspace: str,
embedding: Embeddings,
table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
):
"""
Initialize the cache with all relevant parameters.
Args:
session (cassandra.cluster.Session): an open Cassandra session
keyspace (str): the keyspace to use for storing the cache
            embedding (Embeddings): Embedding provider for semantic
encoding and search.
table_name (str): name of the Cassandra (vector) table
to use as cache
            distance_metric (str, default 'dot'): which measure to adopt for
similarity searches
score_threshold (optional float): numeric value to use as
cutoff for the similarity searches
ttl_seconds (optional int): time-to-live for cache entries
(default: None, i.e. forever)
The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric.
"""
try:
from cassio.table import MetadataVectorCassandraTable
except (ImportError, ModuleNotFoundError):
raise ValueError(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
self.session = session
self.keyspace = keyspace
self.embedding = embedding
self.table_name = table_name
self.distance_metric = distance_metric
self.score_threshold = score_threshold
self.ttl_seconds = ttl_seconds
# The contract for this class has separate lookup and update:
# in order to spare some embedding calculations we cache them between
# the two calls.
# Note: each instance of this class has its own `_get_embedding` with
# its own lru.
@lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
def _cache_embedding(text: str) -> List[float]:
return self.embedding.embed_query(text=text)
self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
self.table = MetadataVectorCassandraTable(
session=self.session,
keyspace=self.keyspace,
table=self.table_name,
primary_key_type=["TEXT"],
vector_dimension=self.embedding_dimension,
ttl_seconds=self.ttl_seconds,
metadata_indexing=("allow", {"_llm_string_hash"}),
)
def _get_embedding_dimension(self) -> int:
return len(self._get_embedding(text="This is a sample sentence."))
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
embedding_vector = self._get_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
body = _dump_generations_to_json(return_val)
metadata = {
"_prompt": prompt,
"_llm_string_hash": llm_string_hash,
}
row_id = f"{_hash(prompt)}-{llm_string_hash}"
#
self.table.put(
body_blob=body,
vector=embedding_vector,
row_id=row_id,
metadata=metadata,
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
def lookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
"""
Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry)
"""
prompt_embedding: List[float] = self._get_embedding(text=prompt)
hits = list(
self.table.metric_ann_search(
vector=prompt_embedding,
metadata={"_llm_string_hash": _hash(llm_string)},
n=1,
metric=self.distance_metric,
metric_threshold=self.score_threshold,
)
)
if hits:
hit = hits[0]
generations_str = hit["body_blob"]
return (
hit["row_id"],
_load_generations_from_json(generations_str),
)
else:
return None
def lookup_with_id_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
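        """
        A wrapper around `lookup_with_id` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here.
        """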
llm_string = get_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)[1]
return self.lookup_with_id(prompt, llm_string=llm_string)
def delete_by_document_id(self, document_id: str) -> None:
"""
Given this is a "similarity search" cache, an invalidation pattern
that makes sense is first a lookup to get an ID, and then deleting
with that ID. This is for the second step.
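
        A sketch of the full pattern (names illustrative):
            hit = semantic_cache.lookup_with_id(prompt, llm_string)
            if hit is not None:
                document_id, _ = hit
                semantic_cache.delete_by_document_id(document_id)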
"""
self.table.delete(row_id=document_id)
def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
self.table.clear()

File diff suppressed because it is too large.


@@ -109,7 +109,7 @@ azure-search-documents = {version = "11.4.0b8", optional = true}
esprima = {version = "^4.0.1", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true}
-cassio = {version = "^0.0.7", optional = true}
+cassio = {version = "^0.1.0", optional = true}
rdflib = {version = "^6.3.2", optional = true}
sympy = {version = "^1.12", optional = true}
rapidfuzz = {version = "^3.1.1", optional = true}
@@ -174,7 +174,7 @@ pytest-vcr = "^1.0.2"
wrapt = "^1.15.0"
openai = "^0.27.4"
python-dotenv = "^1.0.0"
-cassio = "^0.0.7"
+cassio = "^0.1.0"
tiktoken = "^0.3.2"

[tool.poetry.group.lint.dependencies]


@@ -0,0 +1,100 @@
"""Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
import os
import time
from typing import Any, Iterator, Tuple
import pytest
import langchain
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture(scope="module")
def cassandra_connection() -> Iterator[Tuple[Any, str]]:
from cassandra.cluster import Cluster
keyspace = "langchain_cache_test_keyspace"
# get db connection
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = os.environ["CONTACT_POINTS"].split(",")
cluster = Cluster(contact_points)
else:
cluster = Cluster()
#
session = cluster.connect()
# ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
)
)
yield (session, keyspace)
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session, keyspace)
langchain.llm_cache = cache
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
cache.clear()
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session, keyspace, ttl_seconds=2)
langchain.llm_cache = cache
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
output = llm.generate(["foo"])
assert output == expected_output
time.sleep(2.5)
    # the entry has expired by now
output = llm.generate(["foo"])
assert output != expected_output
cache.clear()
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
langchain.llm_cache = sem_cache
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
sem_cache.clear()
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != expected_output
sem_cache.clear()