mirror of https://github.com/hwchase17/langchain, synced 2024-11-04 06:00:26 +00:00
Cassandra support for LLM cache (exact-match and semantic) (#9772)
This PR implements two new classes in the cache module: `CassandraCache` and `CassandraSemanticCache`, similar in structure and functionality to their Redis counterparts: they provide a cache for the response to a (prompt, llm) pair.

Integration tests are included, and linting and type checks pass on my machine.

Dependencies: `pyproject.toml` and `poetry.lock` reference the newest version of cassIO (the very same as in the Cassandra vector store metadata PR, submitted as #9280). If I may suggest, this PR and #9280 might be reviewed together (as they bring the same poetry changes along), so I'm tagging @baskaryan, who already helped out a little with poetry-related conflicts there. (Thank you!)

I'd be happy to add a short notebook if this is deemed necessary (but it seems to me that, unlike e.g. vector stores, caches are not covered in dedicated notebooks). Thank you!

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
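In lieu of a notebook, a minimal usage sketch (not part of the diff): it assumes a reachable Cassandra cluster on localhost, an already-existing keyspace, and the `cassio` package installed; the embedding choice for the semantic variant is purely illustrative. The global `langchain.llm_cache` wiring mirrors the integration tests added below.

```python
import langchain
from cassandra.cluster import Cluster
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.embeddings import OpenAIEmbeddings  # illustrative embedding choice

# Open a session against a local cluster; the keyspace is assumed to exist.
session = Cluster(["127.0.0.1"]).connect()
keyspace = "my_keyspace"  # hypothetical keyspace name

# Exact-match cache: responses keyed by (hashed) prompt and llm_string.
langchain.llm_cache = CassandraCache(session=session, keyspace=keyspace)

# Or the semantic variant: lookups are ANN searches over prompt embeddings.
langchain.llm_cache = CassandraSemanticCache(
    session=session,
    keyspace=keyspace,
    embedding=OpenAIEmbeddings(),
)
```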
This commit is contained in: parent 16945c9922, commit c9ff0ab2e9
@@ -27,6 +27,7 @@ import json
 import logging
 import warnings
 from datetime import timedelta
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -51,6 +52,7 @@ except ImportError:


 from langchain.embeddings.base import Embeddings
+from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
 from langchain.schema import ChatGeneration, Generation
@@ -62,6 +64,7 @@ logger = logging.getLogger(__file__)

 if TYPE_CHECKING:
     import momento
+    from cassandra.cluster import Session as CassandraSession


 def _hash(_input: str) -> str:
@@ -711,3 +714,264 @@ class MomentoCache(BaseCache):
             pass
         elif isinstance(flush_response, CacheFlush.Error):
             raise flush_response.inner_exception
+
+
+CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache"
+CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None
+
+
+class CassandraCache(BaseCache):
+    """
+    Cache that uses Cassandra / Astra DB as a backend.
+
+    It uses a single Cassandra table.
+    The lookup keys (which get to form the primary key) are:
+        - prompt, a string
+        - llm_string, a deterministic str representation of the model parameters
+          (needed to prevent same-prompt-different-model collisions)
+    """
+
+    def __init__(
+        self,
+        session: CassandraSession,
+        keyspace: str,
+        table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
+        ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
+    ):
+        """
+        Initialize with a ready session and a keyspace name.
+        Args:
+            session (cassandra.cluster.Session): an open Cassandra session
+            keyspace (str): the keyspace to use for storing the cache
+            table_name (str): name of the Cassandra table to use as cache
+            ttl_seconds (optional int): time-to-live for cache entries
+                (default: None, i.e. forever)
+        """
+        try:
+            from cassio.table import ElasticCassandraTable
+        except (ImportError, ModuleNotFoundError):
+            raise ValueError(
+                "Could not import cassio python package. "
+                "Please install it with `pip install cassio`."
+            )
+
+        self.session = session
+        self.keyspace = keyspace
+        self.table_name = table_name
+        self.ttl_seconds = ttl_seconds
+
+        self.kv_cache = ElasticCassandraTable(
+            session=self.session,
+            keyspace=self.keyspace,
+            table=self.table_name,
+            keys=["llm_string", "prompt"],
+            primary_key_type=["TEXT", "TEXT"],
+            ttl_seconds=self.ttl_seconds,
+        )
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        item = self.kv_cache.get(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+        )
+        if item:
+            return _load_generations_from_json(item["body_blob"])
+        else:
+            return None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update cache based on prompt and llm_string."""
+        blob = _dump_generations_to_json(return_val)
+        self.kv_cache.put(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+            body_blob=blob,
+        )
+
+    def delete_through_llm(
+        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+    ) -> None:
+        """
+        A wrapper around `delete` with the LLM being passed.
+        In case the llm(prompt) calls have a `stop` param, you should pass it here
+        """
+        llm_string = get_prompts(
+            {**llm.dict(), **{"stop": stop}},
+            [],
+        )[1]
+        return self.delete(prompt, llm_string=llm_string)
+
+    def delete(self, prompt: str, llm_string: str) -> None:
+        """Evict from cache if there's an entry."""
+        return self.kv_cache.delete(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+        )
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache. This is for all LLMs at once."""
+        self.kv_cache.clear()
+
+
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache"
+CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None
+CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
+
+
+class CassandraSemanticCache(BaseCache):
+    """
+    Cache that uses Cassandra as a vector-store backend for semantic
+    (i.e. similarity-based) lookup.
+
+    It uses a single (vector) Cassandra table and stores, in principle,
+    cached values from several LLMs, so the LLM's llm_string is part
+    of the rows' primary keys.
+
+    The similarity is based on one of several distance metrics (default: "dot").
+    If choosing another metric, the default threshold is to be re-tuned accordingly.
+    """
+
+    def __init__(
+        self,
+        session: CassandraSession,
+        keyspace: str,
+        embedding: Embeddings,
+        table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
+        distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
+        score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
+        ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
+    ):
+        """
+        Initialize the cache with all relevant parameters.
+        Args:
+            session (cassandra.cluster.Session): an open Cassandra session
+            keyspace (str): the keyspace to use for storing the cache
+            embedding (Embedding): Embedding provider for semantic
+                encoding and search.
+            table_name (str): name of the Cassandra (vector) table
+                to use as cache
+            distance_metric (str, 'dot'): which measure to adopt for
+                similarity searches
+            score_threshold (optional float): numeric value to use as
+                cutoff for the similarity searches
+            ttl_seconds (optional int): time-to-live for cache entries
+                (default: None, i.e. forever)
+        The default score threshold is tuned to the default metric.
+        Tune it carefully yourself if switching to another distance metric.
+        """
+        try:
+            from cassio.table import MetadataVectorCassandraTable
+        except (ImportError, ModuleNotFoundError):
+            raise ValueError(
+                "Could not import cassio python package. "
+                "Please install it with `pip install cassio`."
+            )
+        self.session = session
+        self.keyspace = keyspace
+        self.embedding = embedding
+        self.table_name = table_name
+        self.distance_metric = distance_metric
+        self.score_threshold = score_threshold
+        self.ttl_seconds = ttl_seconds
+
+        # The contract for this class has separate lookup and update:
+        # in order to spare some embedding calculations we cache them between
+        # the two calls.
+        # Note: each instance of this class has its own `_get_embedding` with
+        # its own lru.
+        @lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+        def _cache_embedding(text: str) -> List[float]:
+            return self.embedding.embed_query(text=text)
+
+        self._get_embedding = _cache_embedding
+        self.embedding_dimension = self._get_embedding_dimension()
+
+        self.table = MetadataVectorCassandraTable(
+            session=self.session,
+            keyspace=self.keyspace,
+            table=self.table_name,
+            primary_key_type=["TEXT"],
+            vector_dimension=self.embedding_dimension,
+            ttl_seconds=self.ttl_seconds,
+            metadata_indexing=("allow", {"_llm_string_hash"}),
+        )
+
+    def _get_embedding_dimension(self) -> int:
+        return len(self._get_embedding(text="This is a sample sentence."))
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update cache based on prompt and llm_string."""
+        embedding_vector = self._get_embedding(text=prompt)
+        llm_string_hash = _hash(llm_string)
+        body = _dump_generations_to_json(return_val)
+        metadata = {
+            "_prompt": prompt,
+            "_llm_string_hash": llm_string_hash,
+        }
+        row_id = f"{_hash(prompt)}-{llm_string_hash}"
+        self.table.put(
+            body_blob=body,
+            vector=embedding_vector,
+            row_id=row_id,
+            metadata=metadata,
+        )
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        hit_with_id = self.lookup_with_id(prompt, llm_string)
+        if hit_with_id is not None:
+            return hit_with_id[1]
+        else:
+            return None
+
+    def lookup_with_id(
+        self, prompt: str, llm_string: str
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        """
+        Look up based on prompt and llm_string.
+        If there are hits, return (document_id, cached_entry)
+        """
+        prompt_embedding: List[float] = self._get_embedding(text=prompt)
+        hits = list(
+            self.table.metric_ann_search(
+                vector=prompt_embedding,
+                metadata={"_llm_string_hash": _hash(llm_string)},
+                n=1,
+                metric=self.distance_metric,
+                metric_threshold=self.score_threshold,
+            )
+        )
+        if hits:
+            hit = hits[0]
+            generations_str = hit["body_blob"]
+            return (
+                hit["row_id"],
+                _load_generations_from_json(generations_str),
+            )
+        else:
+            return None
+
+    def lookup_with_id_through_llm(
+        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        llm_string = get_prompts(
+            {**llm.dict(), **{"stop": stop}},
+            [],
+        )[1]
+        return self.lookup_with_id(prompt, llm_string=llm_string)
+
+    def delete_by_document_id(self, document_id: str) -> None:
+        """
+        Given this is a "similarity search" cache, an invalidation pattern
+        that makes sense is first a lookup to get an ID, and then deleting
+        with that ID. This is for the second step.
+        """
+        self.table.delete(row_id=document_id)
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear the *whole* semantic cache."""
+        self.table.clear()
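The semantic class's docstring above frames invalidation as "first a lookup to get an ID, and then deleting with that ID"; the sketch below exercises both invalidation paths via the methods added in this hunk (`delete_through_llm`, `lookup_with_id_through_llm`, `delete_by_document_id`). The contact point, keyspace, and the OpenAI LLM/embedding choices are illustrative assumptions, not part of the PR.

```python
from cassandra.cluster import Cluster
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.embeddings import OpenAIEmbeddings  # illustrative embedding choice
from langchain.llms import OpenAI  # illustrative LLM choice

session = Cluster(["127.0.0.1"]).connect()  # assumes a local cluster
keyspace = "my_keyspace"                    # hypothetical, assumed to exist

llm = OpenAI()

# Exact-match cache: evict a single (prompt, llm) entry.
cache = CassandraCache(session=session, keyspace=keyspace)
cache.delete_through_llm("my prompt", llm)  # builds the llm_string (incl. stop) for you

# Semantic cache: look up to obtain the row id, then delete by that id.
sem_cache = CassandraSemanticCache(
    session=session,
    keyspace=keyspace,
    embedding=OpenAIEmbeddings(),
)
hit = sem_cache.lookup_with_id_through_llm("a similar prompt", llm)
if hit is not None:
    document_id, cached_generations = hit
    sem_cache.delete_by_document_id(document_id)
```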
libs/langchain/poetry.lock (generated, 512 changes): file diff suppressed because it is too large.
libs/langchain/pyproject.toml

@@ -109,7 +109,7 @@ azure-search-documents = {version = "11.4.0b8", optional = true}
 esprima = {version = "^4.0.1", optional = true}
 streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
 psychicapi = {version = "^0.8.0", optional = true}
-cassio = {version = "^0.0.7", optional = true}
+cassio = {version = "^0.1.0", optional = true}
 rdflib = {version = "^6.3.2", optional = true}
 sympy = {version = "^1.12", optional = true}
 rapidfuzz = {version = "^3.1.1", optional = true}

@@ -174,7 +174,7 @@ pytest-vcr = "^1.0.2"
 wrapt = "^1.15.0"
 openai = "^0.27.4"
 python-dotenv = "^1.0.0"
-cassio = "^0.0.7"
+cassio = "^0.1.0"
 tiktoken = "^0.3.2"

 [tool.poetry.group.lint.dependencies]
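The pin moves from `^0.0.7` to `^0.1.0` because the new cache classes import their table abstractions from `cassio.table` (see the cache.py hunk above). A quick sanity check, as a sketch, after installing the dependency (the error message baked into both classes suggests `pip install cassio`):

```python
# These are exactly the names the new cache classes import lazily at __init__ time;
# if this fails, install/upgrade cassio as the in-code error message suggests.
from cassio.table import ElasticCassandraTable, MetadataVectorCassandraTable
```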
libs/langchain/tests/integration_tests/cache/test_cassandra.py (new file, 100 lines, vendored)
@@ -0,0 +1,100 @@
+"""Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
+import os
+import time
+from typing import Any, Iterator, Tuple
+
+import pytest
+
+import langchain
+from langchain.cache import CassandraCache, CassandraSemanticCache
+from langchain.schema import Generation, LLMResult
+from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+@pytest.fixture(scope="module")
+def cassandra_connection() -> Iterator[Tuple[Any, str]]:
+    from cassandra.cluster import Cluster
+
+    keyspace = "langchain_cache_test_keyspace"
+    # get db connection
+    if "CASSANDRA_CONTACT_POINTS" in os.environ:
+        contact_points = os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
+        cluster = Cluster(contact_points)
+    else:
+        cluster = Cluster()
+    session = cluster.connect()
+    # ensure keyspace exists
+    session.execute(
+        (
+            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
+            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
+        )
+    )
+
+    yield (session, keyspace)
+
+
+def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(session, keyspace)
+    langchain.llm_cache = cache
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo"])
+    print(output)
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    print(expected_output)
+    assert output == expected_output
+    cache.clear()
+
+
+def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(session, keyspace, ttl_seconds=2)
+    langchain.llm_cache = cache
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    output = llm.generate(["foo"])
+    assert output == expected_output
+    time.sleep(2.5)
+    # entry has expired away.
+    output = llm.generate(["foo"])
+    assert output != expected_output
+    cache.clear()
+
+
+def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
+    langchain.llm_cache = sem_cache
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["bar"])  # same embedding as 'foo'
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    # clear the cache
+    sem_cache.clear()
+    output = llm.generate(["bar"])  # 'fizz' is erased away now
+    assert output != expected_output
+    sem_cache.clear()