CassandraCache and CassandraSemanticCache can handle any "Generation" (#10563)

Hello,
this PR broadens what the two Cassandra-related caches (exact-match
and semantic alike) can cache, by switching to the more general
`dumps`/`loads` serde utilities.

This enables cache usage in chat-model contexts such as `ChatOpenAI`
(which need to store lists of `ChatGeneration` rather than plain
`Generation`s), which was not possible as long as the cache classes
relied on the legacy `_dump_generations_to_json` and
`_load_generations_from_json` helpers.
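
For illustration, a minimal sketch of the use case this unlocks (not part of the diff). The contact point, keyspace name, and model are assumptions, and the keyspace is assumed to already exist:

```python
import langchain
from cassandra.cluster import Cluster
from langchain.cache import CassandraCache
from langchain.chat_models import ChatOpenAI

# Assumed: a local Cassandra node and a pre-created keyspace named "demo_ks".
session = Cluster(["127.0.0.1"]).connect()
langchain.llm_cache = CassandraCache(session=session, keyspace="demo_ks")

chat = ChatOpenAI()
chat.predict("Tell me a joke")  # computed, then cached as a list of `ChatGeneration`
chat.predict("Tell me a joke")  # served from the Cassandra cache
```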

Additionally, a slightly different init signature is introduced for the
cache objects (see the sketch after this list):
- named parameters are now required at init time, to pave the way for
easier changes to the future connect-to-db flow (tests adjusted accordingly);
- a `skip_provisioning` optional passthrough parameter is added for use
cases where the user knows the underlying DB table, etc., already exists.
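
A brief sketch of the adjusted constructor calls, assuming `session` is a ready Cassandra session (as in the sketch above), the keyspace and its tables already exist, and the embedding class is just a placeholder:

```python
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.embeddings import OpenAIEmbeddings  # placeholder; any Embeddings works

cache = CassandraCache(
    session=session,           # keyword arguments, per the new convention
    keyspace="demo_ks",
    skip_provisioning=True,    # skip table/index creation: it is assumed to exist
)
sem_cache = CassandraSemanticCache(
    session=session,
    keyspace="demo_ks",
    embedding=OpenAIEmbeddings(),
    skip_provisioning=True,
)
```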

Thank you for a review!
Stefano Lottini 2023-09-14 17:33:06 +02:00 committed by GitHub
parent e1e01d6586
commit 49b65a1b57
2 changed files with 96 additions and 16 deletions


@@ -80,6 +80,8 @@ def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
     Returns:
         str: Json representing a list of generations.
+
+    Warning: would not work well with arbitrary subclasses of `Generation`
     """
     return json.dumps([generation.dict() for generation in generations])
@@ -95,6 +97,8 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
     Returns:
         RETURN_VAL_TYPE: A list of generations.
+
+    Warning: would not work well with arbitrary subclasses of `Generation`
     """
     try:
         results = json.loads(generations_json)
@@ -105,6 +109,65 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
         )


+def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
+    """
+    Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
+
+    Args:
+        generations (RETURN_VAL_TYPE): A list of language model generations.
+
+    Returns:
+        str: a single string representing a list of generations.
+
+    This function (+ its counterpart `_loads_generations`) relies on
+    the dumps/loads pair with Reviver, so it is able to deal
+    with all subclasses of Generation.
+
+    Each item in the list is `dumps`ed to a string,
+    then the whole list of strings is json-dumped into a single string.
+    """
+    return json.dumps([dumps(_item) for _item in generations])
+
+
+def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
+    """
+    Deserialization of a string into a generic RETURN_VAL_TYPE
+    (i.e. a sequence of `Generation`).
+
+    See `_dumps_generations`, the inverse of this function.
+
+    Args:
+        generations_str (str): A string representing a list of generations.
+
+    Compatible with the legacy cache-blob format.
+    Does not raise exceptions for malformed entries, just logs a warning
+    and returns None: the caller should be prepared for such a cache miss.
+
+    Returns:
+        RETURN_VAL_TYPE: A list of generations.
+    """
+    try:
+        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
+        return generations
+    except (json.JSONDecodeError, TypeError):
+        # deferring the (soft) handling to after the legacy-format attempt
+        pass
+
+    try:
+        gen_dicts = json.loads(generations_str)
+        # not relying on `_load_generations_from_json` (which could disappear):
+        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
+        logger.warning(
+            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
+        )
+        return generations
+    except (json.JSONDecodeError, TypeError):
+        logger.warning(
+            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
+        )
+        return None
+
+
 class InMemoryCache(BaseCache):
     """Cache that stores things in memory."""
@@ -733,10 +796,11 @@ class CassandraCache(BaseCache):

     def __init__(
         self,
-        session: CassandraSession,
-        keyspace: str,
+        session: Optional[CassandraSession] = None,
+        keyspace: Optional[str] = None,
         table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
         ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
+        skip_provisioning: bool = False,
     ):
         """
         Initialize with a ready session and a keyspace name.
@@ -767,6 +831,7 @@ class CassandraCache(BaseCache):
             keys=["llm_string", "prompt"],
             primary_key_type=["TEXT", "TEXT"],
             ttl_seconds=self.ttl_seconds,
+            skip_provisioning=skip_provisioning,
         )

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
@@ -775,14 +840,19 @@ class CassandraCache(BaseCache):
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
         )
-        if item:
-            return _load_generations_from_json(item["body_blob"])
+        if item is not None:
+            generations = _loads_generations(item["body_blob"])
+            # this protects against malformed cached items:
+            if generations is not None:
+                return generations
+            else:
+                return None
         else:
             return None

     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
-        blob = _dump_generations_to_json(return_val)
+        blob = _dumps_generations(return_val)
         self.kv_cache.put(
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
@@ -836,13 +906,14 @@ class CassandraSemanticCache(BaseCache):

     def __init__(
         self,
-        session: CassandraSession,
-        keyspace: str,
+        session: Optional[CassandraSession],
+        keyspace: Optional[str],
         embedding: Embeddings,
         table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
         distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
         score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
         ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
+        skip_provisioning: bool = False,
     ):
         """
         Initialize the cache with all relevant parameters.
@@ -897,6 +968,7 @@ class CassandraSemanticCache(BaseCache):
             vector_dimension=self.embedding_dimension,
             ttl_seconds=self.ttl_seconds,
             metadata_indexing=("allow", {"_llm_string_hash"}),
+            skip_provisioning=skip_provisioning,
         )

     def _get_embedding_dimension(self) -> int:
@@ -906,7 +978,7 @@ class CassandraSemanticCache(BaseCache):
         """Update cache based on prompt and llm_string."""
         embedding_vector = self._get_embedding(text=prompt)
         llm_string_hash = _hash(llm_string)
-        body = _dump_generations_to_json(return_val)
+        body = _dumps_generations(return_val)
         metadata = {
             "_prompt": prompt,
             "_llm_string_hash": llm_string_hash,
@@ -947,11 +1019,15 @@ class CassandraSemanticCache(BaseCache):
         )
         if hits:
             hit = hits[0]
-            generations_str = hit["body_blob"]
-            return (
-                hit["row_id"],
-                _load_generations_from_json(generations_str),
-            )
+            generations = _loads_generations(hit["body_blob"])
+            if generations is not None:
+                # this protects against malformed cached items:
+                return (
+                    hit["row_id"],
+                    generations,
+                )
+            else:
+                return None
         else:
             return None


@@ -38,7 +38,7 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
 def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    cache = CassandraCache(session, keyspace)
+    cache = CassandraCache(session=session, keyspace=keyspace)
     langchain.llm_cache = cache
     llm = FakeLLM()
     params = llm.dict()
@@ -58,7 +58,7 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    cache = CassandraCache(session, keyspace, ttl_seconds=2)
+    cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
     langchain.llm_cache = cache
     llm = FakeLLM()
     params = llm.dict()
@@ -80,7 +80,11 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
 def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
+    sem_cache = CassandraSemanticCache(
+        session=session,
+        keyspace=keyspace,
+        embedding=FakeEmbeddings(),
+    )
     langchain.llm_cache = sem_cache
     llm = FakeLLM()
     params = llm.dict()