@ -80,6 +80,8 @@ def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
Returns :
str : Json representing a list of generations .
Warning : would not work well with arbitrary subclasses of ` Generation `
"""
return json . dumps ( [ generation . dict ( ) for generation in generations ] )
@ -95,6 +97,8 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
Returns :
RETURN_VAL_TYPE : A list of generations .
Warning : would not work well with arbitrary subclasses of ` Generation `
"""
try :
results = json . loads ( generations_json )
@ -105,6 +109,65 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.
    Returns:
        str: a single string representing a list of generations.

    This function (+ its counterpart `_loads_generations`) rely on
    the dumps/loads pair with Reviver, so are able to deal
    with all subclasses of Generation.

    Each item in the list can be `dumps`ed to a string,
    then we make the whole list of strings into a json-dumped.
    """
    # Serialize each generation on its own, then pack the strings into
    # a single JSON array.
    serialized_items = [dumps(generation) for generation in generations]
    return json.dumps(serialized_items)
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    See `_dumps_generations`, the inverse of this function.

    Compatible with the legacy cache-blob format.
    Does not raise exceptions for malformed entries, just logs a warning
    and returns None: the caller should be prepared for such a cache miss.

    Args:
        generations_str (str): A string representing a list of generations.
    Returns:
        RETURN_VAL_TYPE: A list of generations, or None on parse failure.
    """
    try:
        # Current format: a JSON array of `dumps`-serialized generations.
        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
        return generations
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        # Legacy format: a JSON array of plain `Generation.dict()` payloads.
        gen_dicts = json.loads(generations_str)
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
        # lazy %-style args so the blob is only formatted if the record is emitted
        logger.warning(
            "Legacy 'Generation' cached blob encountered: '%s'",
            generations_str,
        )
        return generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            "Malformed/unparsable cached blob encountered: '%s'",
            generations_str,
        )
        return None
class InMemoryCache ( BaseCache ) :
""" Cache that stores things in memory. """
@ -733,10 +796,11 @@ class CassandraCache(BaseCache):
def __init__ (
self ,
session : CassandraSession,
keyspace : str ,
session : Optional[ CassandraSession] = None ,
keyspace : Optional [ str ] = None ,
table_name : str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME ,
ttl_seconds : Optional [ int ] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS ,
skip_provisioning : bool = False ,
) :
"""
Initialize with a ready session and a keyspace name .
@ -767,6 +831,7 @@ class CassandraCache(BaseCache):
keys = [ " llm_string " , " prompt " ] ,
primary_key_type = [ " TEXT " , " TEXT " ] ,
ttl_seconds = self . ttl_seconds ,
skip_provisioning = skip_provisioning ,
)
def lookup ( self , prompt : str , llm_string : str ) - > Optional [ RETURN_VAL_TYPE ] :
@ -775,14 +840,19 @@ class CassandraCache(BaseCache):
llm_string = _hash ( llm_string ) ,
prompt = _hash ( prompt ) ,
)
if item :
return _load_generations_from_json ( item [ " body_blob " ] )
if item is not None :
generations = _loads_generations ( item [ " body_blob " ] )
# this protects against malformed cached items:
if generations is not None :
return generations
else :
return None
else :
return None
def update ( self , prompt : str , llm_string : str , return_val : RETURN_VAL_TYPE ) - > None :
""" Update cache based on prompt and llm_string. """
blob = _dump _generations_to_json ( return_val )
blob = _dump s _generations( return_val )
self . kv_cache . put (
llm_string = _hash ( llm_string ) ,
prompt = _hash ( prompt ) ,
@ -836,13 +906,14 @@ class CassandraSemanticCache(BaseCache):
def __init__ (
self ,
session : CassandraSession,
keyspace : str ,
session : Optional[ CassandraSession] ,
keyspace : Optional [ str ] ,
embedding : Embeddings ,
table_name : str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME ,
distance_metric : str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC ,
score_threshold : float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD ,
ttl_seconds : Optional [ int ] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS ,
skip_provisioning : bool = False ,
) :
"""
Initialize the cache with all relevant parameters .
@ -897,6 +968,7 @@ class CassandraSemanticCache(BaseCache):
vector_dimension = self . embedding_dimension ,
ttl_seconds = self . ttl_seconds ,
metadata_indexing = ( " allow " , { " _llm_string_hash " } ) ,
skip_provisioning = skip_provisioning ,
)
def _get_embedding_dimension ( self ) - > int :
@ -906,7 +978,7 @@ class CassandraSemanticCache(BaseCache):
""" Update cache based on prompt and llm_string. """
embedding_vector = self . _get_embedding ( text = prompt )
llm_string_hash = _hash ( llm_string )
body = _dump _generations_to_json ( return_val )
body = _dump s _generations( return_val )
metadata = {
" _prompt " : prompt ,
" _llm_string_hash " : llm_string_hash ,
@ -947,11 +1019,15 @@ class CassandraSemanticCache(BaseCache):
)
if hits :
hit = hits [ 0 ]
generations_str = hit [ " body_blob " ]
return (
hit [ " row_id " ] ,
_load_generations_from_json ( generations_str ) ,
)
generations = _loads_generations ( hit [ " body_blob " ] )
if generations is not None :
# this protects against malformed cached items:
return (
hit [ " row_id " ] ,
generations ,
)
else :
return None
else :
return None