mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
CassandraCache and CassandraSemanticCache can handle any "Generation" (#10563)
Hello, this PR improves coverage for caching by the two Cassandra-related caches (i.e. exact-match and semantic alike) by switching to the more general `dumps`/`loads` serdes utilities. This enables cache usage within e.g. `ChatOpenAI` contexts (which need to store lists of `ChatGeneration` instead of `Generation`s), which was not possible as long as the cache classes were relying on the legacy `_dump_generations_to_json` and `_load_generations_from_json`. Additionally, a slightly different init signature is introduced for the cache objects: - named parameters required for init, to pave the way for easier changes in the future connect-to-db flow (and tests adjusted accordingly) - added a `skip_provisioning` optional passthrough parameter for use cases where the user knows the underlying DB table, etc. already exist. Thank you for a review!
This commit is contained in:
parent
e1e01d6586
commit
49b65a1b57
@ -80,6 +80,8 @@ def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Json representing a list of generations.
|
str: Json representing a list of generations.
|
||||||
|
|
||||||
|
Warning: would not work well with arbitrary subclasses of `Generation`
|
||||||
"""
|
"""
|
||||||
return json.dumps([generation.dict() for generation in generations])
|
return json.dumps([generation.dict() for generation in generations])
|
||||||
|
|
||||||
@ -95,6 +97,8 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RETURN_VAL_TYPE: A list of generations.
|
RETURN_VAL_TYPE: A list of generations.
|
||||||
|
|
||||||
|
Warning: would not work well with arbitrary subclasses of `Generation`
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = json.loads(generations_json)
|
results = json.loads(generations_json)
|
||||||
@ -105,6 +109,65 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialize a generic RETURN_VAL_TYPE (i.e. a sequence of `Generation`)
    into a single string.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.

    Together with its counterpart `_loads_generations`, this function relies
    on the `dumps`/`loads` pair rather than on `Generation.dict()`, which is
    what lets it deal with subclasses of `Generation`.

    The encoding has two layers: every item of the input is first turned
    into a string with `dumps`, then the resulting list of strings is itself
    JSON-encoded.
    """
    # inner layer: one `dumps`-produced string per generation
    serialized_items = [dumps(generation) for generation in generations]
    # outer layer: the list of strings as one JSON document
    return json.dumps(serialized_items)
|
||||||
|
|
||||||
|
|
||||||
|
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    See `_dumps_generations`, the inverse of this function.

    Args:
        generations_str (str): A string representing a list of generations.

    Two formats are attempted, in this order:

    1. the current one: a JSON list of strings, each of which is handed
       to `loads`;
    2. the legacy cache-blob format: a JSON list of dictionaries, each of
       which is passed as keyword arguments to `Generation` (a warning is
       logged when a blob is recovered this way).

    Does not raise exceptions for malformed entries, just logs a warning
    and returns None: the caller should be prepared for such a cache miss.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if the blob could
        not be parsed in either format.
    """
    # Parse the outer JSON layer once; both formats share it.
    # `json.JSONDecodeError` is a subclass of `ValueError`; `TypeError`
    # covers a non-string input.
    try:
        parsed = json.loads(generations_str)
    except (ValueError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    try:
        # `TypeError` also covers `parsed` not being iterable and items
        # that are not strings (e.g. the dictionaries of a legacy blob).
        return [loads(_item_str) for _item_str in parsed]
    except (ValueError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in parsed]
    except (ValueError, TypeError):
        # `ValueError` is caught so that a well-formed JSON blob with the
        # wrong fields becomes a cache miss instead of an exception
        # (presumably `Generation` reports bad fields with a `ValueError`
        # subclass -- confirm against its definition).
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    logger.warning(f"Legacy 'Generation' cached blob encountered: '{generations_str}'")
    return generations
|
||||||
|
|
||||||
|
|
||||||
class InMemoryCache(BaseCache):
|
class InMemoryCache(BaseCache):
|
||||||
"""Cache that stores things in memory."""
|
"""Cache that stores things in memory."""
|
||||||
|
|
||||||
@ -733,10 +796,11 @@ class CassandraCache(BaseCache):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session: CassandraSession,
|
session: Optional[CassandraSession] = None,
|
||||||
keyspace: str,
|
keyspace: Optional[str] = None,
|
||||||
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
|
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
|
||||||
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
|
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
|
||||||
|
skip_provisioning: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with a ready session and a keyspace name.
|
Initialize with a ready session and a keyspace name.
|
||||||
@ -767,6 +831,7 @@ class CassandraCache(BaseCache):
|
|||||||
keys=["llm_string", "prompt"],
|
keys=["llm_string", "prompt"],
|
||||||
primary_key_type=["TEXT", "TEXT"],
|
primary_key_type=["TEXT", "TEXT"],
|
||||||
ttl_seconds=self.ttl_seconds,
|
ttl_seconds=self.ttl_seconds,
|
||||||
|
skip_provisioning=skip_provisioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
@ -775,14 +840,19 @@ class CassandraCache(BaseCache):
|
|||||||
llm_string=_hash(llm_string),
|
llm_string=_hash(llm_string),
|
||||||
prompt=_hash(prompt),
|
prompt=_hash(prompt),
|
||||||
)
|
)
|
||||||
if item:
|
if item is not None:
|
||||||
return _load_generations_from_json(item["body_blob"])
|
generations = _loads_generations(item["body_blob"])
|
||||||
|
# this protects against malformed cached items:
|
||||||
|
if generations is not None:
|
||||||
|
return generations
|
||||||
|
else:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
"""Update cache based on prompt and llm_string."""
|
"""Update cache based on prompt and llm_string."""
|
||||||
blob = _dump_generations_to_json(return_val)
|
blob = _dumps_generations(return_val)
|
||||||
self.kv_cache.put(
|
self.kv_cache.put(
|
||||||
llm_string=_hash(llm_string),
|
llm_string=_hash(llm_string),
|
||||||
prompt=_hash(prompt),
|
prompt=_hash(prompt),
|
||||||
@ -836,13 +906,14 @@ class CassandraSemanticCache(BaseCache):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session: CassandraSession,
|
session: Optional[CassandraSession],
|
||||||
keyspace: str,
|
keyspace: Optional[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
|
table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
|
||||||
distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
|
distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
|
||||||
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
|
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
|
||||||
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
|
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
|
||||||
|
skip_provisioning: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the cache with all relevant parameters.
|
Initialize the cache with all relevant parameters.
|
||||||
@ -897,6 +968,7 @@ class CassandraSemanticCache(BaseCache):
|
|||||||
vector_dimension=self.embedding_dimension,
|
vector_dimension=self.embedding_dimension,
|
||||||
ttl_seconds=self.ttl_seconds,
|
ttl_seconds=self.ttl_seconds,
|
||||||
metadata_indexing=("allow", {"_llm_string_hash"}),
|
metadata_indexing=("allow", {"_llm_string_hash"}),
|
||||||
|
skip_provisioning=skip_provisioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_embedding_dimension(self) -> int:
|
def _get_embedding_dimension(self) -> int:
|
||||||
@ -906,7 +978,7 @@ class CassandraSemanticCache(BaseCache):
|
|||||||
"""Update cache based on prompt and llm_string."""
|
"""Update cache based on prompt and llm_string."""
|
||||||
embedding_vector = self._get_embedding(text=prompt)
|
embedding_vector = self._get_embedding(text=prompt)
|
||||||
llm_string_hash = _hash(llm_string)
|
llm_string_hash = _hash(llm_string)
|
||||||
body = _dump_generations_to_json(return_val)
|
body = _dumps_generations(return_val)
|
||||||
metadata = {
|
metadata = {
|
||||||
"_prompt": prompt,
|
"_prompt": prompt,
|
||||||
"_llm_string_hash": llm_string_hash,
|
"_llm_string_hash": llm_string_hash,
|
||||||
@ -947,11 +1019,15 @@ class CassandraSemanticCache(BaseCache):
|
|||||||
)
|
)
|
||||||
if hits:
|
if hits:
|
||||||
hit = hits[0]
|
hit = hits[0]
|
||||||
generations_str = hit["body_blob"]
|
generations = _loads_generations(hit["body_blob"])
|
||||||
return (
|
if generations is not None:
|
||||||
hit["row_id"],
|
# this protects against malformed cached items:
|
||||||
_load_generations_from_json(generations_str),
|
return (
|
||||||
)
|
hit["row_id"],
|
||||||
|
generations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
|
|||||||
|
|
||||||
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
||||||
session, keyspace = cassandra_connection
|
session, keyspace = cassandra_connection
|
||||||
cache = CassandraCache(session, keyspace)
|
cache = CassandraCache(session=session, keyspace=keyspace)
|
||||||
langchain.llm_cache = cache
|
langchain.llm_cache = cache
|
||||||
llm = FakeLLM()
|
llm = FakeLLM()
|
||||||
params = llm.dict()
|
params = llm.dict()
|
||||||
@ -58,7 +58,7 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
|||||||
|
|
||||||
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
|
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
|
||||||
session, keyspace = cassandra_connection
|
session, keyspace = cassandra_connection
|
||||||
cache = CassandraCache(session, keyspace, ttl_seconds=2)
|
cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
|
||||||
langchain.llm_cache = cache
|
langchain.llm_cache = cache
|
||||||
llm = FakeLLM()
|
llm = FakeLLM()
|
||||||
params = llm.dict()
|
params = llm.dict()
|
||||||
@ -80,7 +80,11 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
|
|||||||
|
|
||||||
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
||||||
session, keyspace = cassandra_connection
|
session, keyspace = cassandra_connection
|
||||||
sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
|
sem_cache = CassandraSemanticCache(
|
||||||
|
session=session,
|
||||||
|
keyspace=keyspace,
|
||||||
|
embedding=FakeEmbeddings(),
|
||||||
|
)
|
||||||
langchain.llm_cache = sem_cache
|
langchain.llm_cache = sem_cache
|
||||||
llm = FakeLLM()
|
llm = FakeLLM()
|
||||||
params = llm.dict()
|
params = llm.dict()
|
||||||
|
Loading…
Reference in New Issue
Block a user