community[minor]: Add async methods to CassandraCache and CassandraSemanticCache (#20654)

pull/21078/head
Christophe Bornet 4 weeks ago committed by GitHub
parent d6e9bd3011
commit 5c77f45b06
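Below, a minimal usage sketch of the new async surface for reviewers (assumes a live, cassio-compatible Cassandra `Session`; the `main` helper and its arguments are illustrative, not part of this diff):

```python
import asyncio

from langchain_community.cache import CassandraCache
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation


async def main(session, keyspace: str) -> None:
    # ASYNC setup lets cassio provision the schema without blocking the loop
    cache = CassandraCache(
        session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
    )
    await cache.aupdate("my prompt", "my llm string", [Generation(text="cached")])
    hit = await cache.alookup("my prompt", "my llm string")  # [Generation] or None
    await cache.aclear()


# asyncio.run(main(session, "my_keyspace"))  # requires a live Cassandra session
```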

@@ -53,6 +53,7 @@ from sqlalchemy.engine import Row
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
from langchain_community.vectorstores.azure_cosmos_db import (
CosmosDBSimilarityType,
CosmosDBVectorSearchType,
@@ -63,7 +64,7 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core._api.deprecation import deprecated
from langchain_core._api.deprecation import deprecated, warn_deprecated
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
@@ -73,7 +74,9 @@ from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env
from langchain_community.utilities.astradb import (
SetupMode,
SetupMode as AstraSetupMode,
)
from langchain_community.utilities.astradb import (
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
@@ -1056,6 +1059,7 @@ class CassandraCache(BaseCache):
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
skip_provisioning: bool = False,
setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
):
"""
Initialize with a ready session and a keyspace name.
@@ -1066,6 +1070,10 @@
ttl_seconds (optional int): time-to-live for cache entries
(default: None, i.e. forever)
"""
if skip_provisioning:
warn_deprecated(
"0.0.33",
name="skip_provisioning",
alternative="setup_mode=CassandraSetupMode.OFF",
)
try:
from cassio.table import ElasticCassandraTable
except (ImportError, ModuleNotFoundError):
@@ -1079,6 +1087,10 @@
self.table_name = table_name
self.ttl_seconds = ttl_seconds
kwargs = {}
if setup_mode == CassandraSetupMode.ASYNC:
kwargs["async_setup"] = True
self.kv_cache = ElasticCassandraTable(
session=self.session,
keyspace=self.keyspace,
@@ -1086,27 +1098,31 @@
keys=["llm_string", "prompt"],
primary_key_type=["TEXT", "TEXT"],
ttl_seconds=self.ttl_seconds,
skip_provisioning=skip_provisioning,
skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
**kwargs,
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
item = self.kv_cache.get(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
)
if item is not None:
generations = _loads_generations(item["body_blob"])
# this protects against malformed cached items:
if generations is not None:
return generations
else:
return None
return _loads_generations(item["body_blob"])
else:
return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
item = await self.kv_cache.aget(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
)
if item is not None:
return _loads_generations(item["body_blob"])
else:
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
blob = _dumps_generations(return_val)
self.kv_cache.put(
llm_string=_hash(llm_string),
@@ -1114,6 +1130,16 @@
body_blob=blob,
)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
blob = _dumps_generations(return_val)
await self.kv_cache.aput(
llm_string=_hash(llm_string),
prompt=_hash(prompt),
body_blob=blob,
)
def delete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
@@ -1139,6 +1165,10 @@
"""Clear cache. This is for all LLMs at once."""
self.kv_cache.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
await self.kv_cache.aclear()
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
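The CassandraSemanticCache changes below wrap the async embedding call in an `_async_lru_cache` decorator: a plain `functools.lru_cache` cannot cache coroutine results, since a coroutine can only be awaited once. The helper itself is not shown in these hunks; a minimal sketch of one possible implementation, under the assumption that it mirrors `lru_cache` semantics:

```python
import asyncio
from functools import lru_cache, wraps
from typing import Any, Callable


def _async_lru_cache(maxsize: int = 128) -> Callable:
    """LRU cache for async functions (illustrative sketch, not the PR's code)."""

    def decorator(async_fn: Callable) -> Callable:
        @lru_cache(maxsize=maxsize)
        def cached(*args: Any, **kwargs: Any) -> "asyncio.Task":
            # Store a Task so concurrent callers share a single computation;
            # awaiting a finished Task simply returns its cached result.
            return asyncio.ensure_future(async_fn(*args, **kwargs))

        @wraps(async_fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            return await cached(*args, **kwargs)

        return wrapper

    return decorator
```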
@@ -1170,6 +1200,7 @@ class CassandraSemanticCache(BaseCache):
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
skip_provisioning: bool = False,
setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
):
"""
Initialize the cache with all relevant parameters.
@@ -1189,6 +1220,10 @@ class CassandraSemanticCache(BaseCache):
The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric.
"""
if skip_provisioning:
warn_deprecated(
"0.0.33",
name="skip_provisioning",
alternative="setup_mode=CassandraSetupMode.OFF",
)
try:
from cassio.table import MetadataVectorCassandraTable
except (ImportError, ModuleNotFoundError):
@@ -1214,24 +1249,42 @@
return self.embedding.embed_query(text=text)
self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
@_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
async def _acache_embedding(text: str) -> List[float]:
return await self.embedding.aembed_query(text=text)
self._aget_embedding = _acache_embedding
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == CassandraSetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == CassandraSetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
kwargs = {}
if setup_mode == CassandraSetupMode.ASYNC:
kwargs["async_setup"] = True
self.table = MetadataVectorCassandraTable(
session=self.session,
keyspace=self.keyspace,
table=self.table_name,
primary_key_type=["TEXT"],
vector_dimension=self.embedding_dimension,
vector_dimension=embedding_dimension,
ttl_seconds=self.ttl_seconds,
metadata_indexing=("allow", {"_llm_string_hash"}),
skip_provisioning=skip_provisioning,
skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
**kwargs,
)
def _get_embedding_dimension(self) -> int:
return len(self._get_embedding(text="This is a sample sentence."))
async def _aget_embedding_dimension(self) -> int:
return len(await self._aget_embedding(text="This is a sample sentence."))
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
embedding_vector = self._get_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
body = _dumps_generations(return_val)
@@ -1240,7 +1293,7 @@ class CassandraSemanticCache(BaseCache):
"_llm_string_hash": llm_string_hash,
}
row_id = f"{_hash(prompt)}-{llm_string_hash}"
#
self.table.put(
body_blob=body,
vector=embedding_vector,
@@ -1248,14 +1301,39 @@
metadata=metadata,
)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
embedding_vector = await self._aget_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
body = _dumps_generations(return_val)
metadata = {
"_prompt": prompt,
"_llm_string_hash": llm_string_hash,
}
row_id = f"{_hash(prompt)}-{llm_string_hash}"
await self.table.aput(
body_blob=body,
vector=embedding_vector,
row_id=row_id,
metadata=metadata,
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
hit_with_id = await self.alookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
def lookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1287,6 +1365,37 @@
else:
return None
async def alookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
"""
Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry)
"""
prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
hits = list(
await self.table.ametric_ann_search(
vector=prompt_embedding,
metadata={"_llm_string_hash": _hash(llm_string)},
n=1,
metric=self.distance_metric,
metric_threshold=self.score_threshold,
)
)
if hits:
hit = hits[0]
generations = _loads_generations(hit["body_blob"])
if generations is not None:
# this protects against malformed cached items:
return (
hit["row_id"],
generations,
)
else:
return None
else:
return None
def lookup_with_id_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1296,6 +1405,17 @@
)[1]
return self.lookup_with_id(prompt, llm_string=llm_string)
async def alookup_with_id_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
llm_string = (
await aget_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)
)[1]
return await self.alookup_with_id(prompt, llm_string=llm_string)
def delete_by_document_id(self, document_id: str) -> None:
"""
Given this is a "similarity search" cache, an invalidation pattern
@@ -1304,10 +1424,22 @@
"""
self.table.delete(row_id=document_id)
async def adelete_by_document_id(self, document_id: str) -> None:
"""
Given this is a "similarity search" cache, an invalidation pattern
that makes sense is first a lookup to get an ID, and then deleting
with that ID. This is for the second step.
"""
await self.table.adelete(row_id=document_id)
def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
self.table.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
await self.table.aclear()
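The docstrings above describe the intended two-step invalidation for a similarity-search cache: look up to obtain a document ID, then delete by that ID. With the new coroutines that reads roughly as follows (a sketch; `sem_cache` is assumed to be an already-constructed CassandraSemanticCache, and `invalidate_entry` is illustrative):

```python
async def invalidate_entry(sem_cache, prompt: str, llm_string: str) -> bool:
    """Delete a semantically-matching cached entry, if any."""
    hit = await sem_cache.alookup_with_id(prompt, llm_string)
    if hit is None:
        return False  # cache miss: nothing to invalidate
    document_id, _generations = hit
    await sem_cache.adelete_by_document_id(document_id)
    return True
```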
class FullMd5LLMCache(Base): # type: ignore
"""SQLite table for full LLM Cache (all generations)."""
@@ -1412,7 +1544,7 @@ class AstraDBCache(BaseCache):
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
):
"""
Cache that uses Astra DB as a backend.
@@ -1612,7 +1744,7 @@ class AstraDBSemanticCache(BaseCache):
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
pre_delete_collection: bool = False,
embedding: Embeddings,
metric: Optional[str] = None,
@@ -1675,9 +1807,9 @@
self._aget_embedding = _acache_embedding
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == SetupMode.ASYNC:
if setup_mode == AstraSetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
elif setup_mode == AstraSetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
self.astra_env = _AstraDBCollectionEnvironment(

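Both the Cassandra and Astra DB semantic caches now compute `embedding_dimension` as either a plain `int` (SYNC mode) or an un-awaited coroutine (ASYNC mode), handing it to the storage layer to resolve during setup. A sketch of how a consumer of such a `Union[int, Awaitable[int], None]` value might normalize it (the actual resolution happens inside cassio / the Astra environment, not in this diff):

```python
import inspect
from typing import Awaitable, Optional, Union


async def resolve_dimension(
    dim: Union[int, Awaitable[int], None],
) -> Optional[int]:
    # ASYNC mode supplies an un-awaited coroutine, SYNC mode a ready int;
    # OFF mode may pass None because no table needs provisioning.
    if inspect.isawaitable(dim):
        return await dim
    return dim
```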
@@ -93,13 +93,47 @@ class BaseCache(ABC):
"""Clear cache that can take additional keyword arguments."""
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Async version of lookup."""
"""Look up based on prompt and llm_string.
A cache implementation is expected to generate a key from the 2-tuple
of prompt and llm_string (e.g., by concatenating them with a delimiter).
Args:
prompt: a string representation of the prompt.
In the case of a chat model, the prompt is a non-trivial
serialization of the messages passed to the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
These invocation parameters are serialized into a string
representation.
Returns:
On a cache miss, return None. On a cache hit, return the cached value.
The cached value is a list of Generations (or subclasses).
"""
return await run_in_executor(None, self.lookup, prompt, llm_string)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Async version of aupdate."""
"""Update cache based on prompt and llm_string.
The prompt and llm_string are used to generate a key for the cache.
The key should match that of the look up method.
Args:
prompt: a string representation of the prompt.
In the case of a chat model, the prompt is a non-trivial
serialization of the messages passed to the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
These invocation parameters are serialized into a string
representation.
return_val: The value to be cached. The value is a list of Generations
(or subclasses).
"""
return await run_in_executor(None, self.update, prompt, llm_string, return_val)
async def aclear(self, **kwargs: Any) -> None:

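The expanded BaseCache docstrings spell out the contract: the (prompt, llm_string) pair identifies an entry, and the default async methods delegate to the sync ones via `run_in_executor`. A toy in-memory implementation of that contract (illustrative only; `DictCache` is a hypothetical name, and the library ships its own in-memory cache):

```python
from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class DictCache(BaseCache):
    """Toy cache keyed by the (prompt, llm_string) 2-tuple."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))

    def update(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


# alookup/aupdate/aclear are inherited: the base class runs the sync
# methods in an executor, so DictCache is already usable from async code.
```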
@@ -1,10 +1,11 @@
"""Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
import asyncio
import os
import time
from typing import Any, Iterator, Tuple
import pytest
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import CassandraCache, CassandraSemanticCache
@@ -47,16 +48,34 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output) # noqa: T201
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output) # noqa: T201
assert output == expected_output
cache.clear()
async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(
session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
await cache.aclear()
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
@@ -79,6 +98,30 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
cache.clear()
async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(
session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
output = await llm.agenerate(["foo"])
assert output == expected_output
await asyncio.sleep(2.5)
# the entry has expired by now
output = await llm.agenerate(["foo"])
assert output != expected_output
await cache.aclear()
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
sem_cache = CassandraSemanticCache(
@@ -103,3 +146,32 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != expected_output
sem_cache.clear()
async def test_cassandra_semantic_cache_async(
cassandra_connection: Tuple[Any, str],
) -> None:
session, keyspace = cassandra_connection
sem_cache = CassandraSemanticCache(
session=session,
keyspace=keyspace,
embedding=FakeEmbeddings(),
setup_mode=SetupMode.ASYNC,
)
set_llm_cache(sem_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
await sem_cache.aclear()
output = await llm.agenerate(["bar"]) # 'fizz' is erased away now
assert output != expected_output
await sem_cache.aclear()
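Note that the new tests are bare `async def` functions with no decorator, which presumably relies on pytest-asyncio running in auto mode for this suite. If that assumption does not hold, each coroutine test needs an explicit marker, e.g.:

```python
import pytest


@pytest.mark.asyncio  # only needed when asyncio_mode != "auto"
async def test_smoke() -> None:
    assert True
```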
