|
|
|
@ -53,6 +53,7 @@ from sqlalchemy.engine import Row
|
|
|
|
|
from sqlalchemy.engine.base import Engine
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
|
|
|
|
|
from langchain_community.vectorstores.azure_cosmos_db import (
|
|
|
|
|
CosmosDBSimilarityType,
|
|
|
|
|
CosmosDBVectorSearchType,
|
|
|
|
@ -63,7 +64,7 @@ try:
|
|
|
|
|
except ImportError:
|
|
|
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
|
|
|
|
|
from langchain_core._api.deprecation import deprecated
|
|
|
|
|
from langchain_core._api.deprecation import deprecated, warn_deprecated
|
|
|
|
|
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
|
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
|
|
|
|
@ -73,7 +74,9 @@ from langchain_core.outputs import ChatGeneration, Generation
|
|
|
|
|
from langchain_core.utils import get_from_env
|
|
|
|
|
|
|
|
|
|
from langchain_community.utilities.astradb import (
|
|
|
|
|
SetupMode,
|
|
|
|
|
SetupMode as AstraSetupMode,
|
|
|
|
|
)
|
|
|
|
|
from langchain_community.utilities.astradb import (
|
|
|
|
|
_AstraDBCollectionEnvironment,
|
|
|
|
|
)
|
|
|
|
|
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
|
|
|
|
@ -1056,6 +1059,7 @@ class CassandraCache(BaseCache):
|
|
|
|
|
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
|
|
|
|
|
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
|
|
|
|
|
skip_provisioning: bool = False,
|
|
|
|
|
setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Initialize with a ready session and a keyspace name.
|
|
|
|
@ -1066,6 +1070,10 @@ class CassandraCache(BaseCache):
|
|
|
|
|
ttl_seconds (optional int): time-to-live for cache entries
|
|
|
|
|
(default: None, i.e. forever)
|
|
|
|
|
"""
|
|
|
|
|
if skip_provisioning:
|
|
|
|
|
warn_deprecated(
|
|
|
|
|
"0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
from cassio.table import ElasticCassandraTable
|
|
|
|
|
except (ImportError, ModuleNotFoundError):
|
|
|
|
@ -1079,6 +1087,10 @@ class CassandraCache(BaseCache):
|
|
|
|
|
self.table_name = table_name
|
|
|
|
|
self.ttl_seconds = ttl_seconds
|
|
|
|
|
|
|
|
|
|
kwargs = {}
|
|
|
|
|
if setup_mode == CassandraSetupMode.ASYNC:
|
|
|
|
|
kwargs["async_setup"] = True
|
|
|
|
|
|
|
|
|
|
self.kv_cache = ElasticCassandraTable(
|
|
|
|
|
session=self.session,
|
|
|
|
|
keyspace=self.keyspace,
|
|
|
|
@ -1086,27 +1098,31 @@ class CassandraCache(BaseCache):
|
|
|
|
|
keys=["llm_string", "prompt"],
|
|
|
|
|
primary_key_type=["TEXT", "TEXT"],
|
|
|
|
|
ttl_seconds=self.ttl_seconds,
|
|
|
|
|
skip_provisioning=skip_provisioning,
|
|
|
|
|
skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Look up based on prompt and llm_string.

    Args:
        prompt: the prompt text; it is hashed with ``_hash`` before being
            used as part of the row key.
        llm_string: the LLM configuration string; hashed the same way.

    Returns:
        Whatever ``_loads_generations`` produces for the stored ``body_blob``,
        or None when ``self.kv_cache.get`` finds no row. A None coming back
        from ``_loads_generations`` is passed through unchanged (the earlier
        version of this method noted that this protects against malformed
        cached items).
    """
    item = self.kv_cache.get(
        llm_string=_hash(llm_string),
        prompt=_hash(prompt),
    )
    if item is None:
        return None
    # Single exit for the hit case: the former nested "is not None" check and
    # the duplicated trailing return/else collapsed into this one statement.
    return _loads_generations(item["body_blob"])
|
|
|
|
|
|
|
|
|
|
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Asynchronously look up a cached entry by prompt and llm_string.

    Both arguments are hashed with ``_hash`` and passed to
    ``self.kv_cache.aget``. When a row comes back, its ``body_blob`` is
    handed to ``_loads_generations`` and that result is returned; otherwise
    None is returned.
    """
    row = await self.kv_cache.aget(
        llm_string=_hash(llm_string),
        prompt=_hash(prompt),
    )
    if row is None:
        return None
    return _loads_generations(row["body_blob"])
|
|
|
|
|
|
|
|
|
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
|
|
|
|
"""Update cache based on prompt and llm_string."""
|
|
|
|
|
blob = _dumps_generations(return_val)
|
|
|
|
|
self.kv_cache.put(
|
|
|
|
|
llm_string=_hash(llm_string),
|
|
|
|
@ -1114,6 +1130,16 @@ class CassandraCache(BaseCache):
|
|
|
|
|
body_blob=blob,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously store ``return_val`` under the (llm_string, prompt) pair.

    The value is serialized with ``_dumps_generations`` first; the two key
    parts are then hashed with ``_hash`` and everything is written through
    ``self.kv_cache.aput``.
    """
    serialized = _dumps_generations(return_val)
    row = {
        "llm_string": _hash(llm_string),
        "prompt": _hash(prompt),
        "body_blob": serialized,
    }
    await self.kv_cache.aput(**row)
|
|
|
|
|
|
|
|
|
|
def delete_through_llm(
|
|
|
|
|
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
|
|
|
|
|
) -> None:
|
|
|
|
@ -1139,6 +1165,10 @@ class CassandraCache(BaseCache):
|
|
|
|
|
"""Clear cache. This is for all LLMs at once."""
|
|
|
|
|
self.kv_cache.clear()
|
|
|
|
|
|
|
|
|
|
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously clear the cache, for all LLMs at once.

    Delegates to ``self.kv_cache.aclear()``; any keyword arguments are
    accepted but not used.
    """
    store = self.kv_cache
    await store.aclear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default distance metric name for the Cassandra semantic cache
# (presumably "dot" selects the dot-product similarity — confirm against cassio).
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
|
|
|
|
|
# Default value of the `score_threshold` constructor parameter of the Cassandra
# semantic cache. Per that constructor's docstring, it is tuned to the default
# metric and should be retuned when switching to another distance metric.
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
|
|
|
|
@ -1170,6 +1200,7 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
|
|
|
|
|
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
|
|
|
|
|
skip_provisioning: bool = False,
|
|
|
|
|
setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Initialize the cache with all relevant parameters.
|
|
|
|
@ -1189,6 +1220,10 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
The default score threshold is tuned to the default metric.
|
|
|
|
|
Tune it carefully yourself if switching to another distance metric.
|
|
|
|
|
"""
|
|
|
|
|
if skip_provisioning:
|
|
|
|
|
warn_deprecated(
|
|
|
|
|
"0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
from cassio.table import MetadataVectorCassandraTable
|
|
|
|
|
except (ImportError, ModuleNotFoundError):
|
|
|
|
@ -1214,24 +1249,42 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
return self.embedding.embed_query(text=text)
|
|
|
|
|
|
|
|
|
|
self._get_embedding = _cache_embedding
|
|
|
|
|
self.embedding_dimension = self._get_embedding_dimension()
|
|
|
|
|
|
|
|
|
|
@_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
|
|
|
|
|
async def _acache_embedding(text: str) -> List[float]:
|
|
|
|
|
return await self.embedding.aembed_query(text=text)
|
|
|
|
|
|
|
|
|
|
self._aget_embedding = _acache_embedding
|
|
|
|
|
|
|
|
|
|
embedding_dimension: Union[int, Awaitable[int], None] = None
|
|
|
|
|
if setup_mode == CassandraSetupMode.ASYNC:
|
|
|
|
|
embedding_dimension = self._aget_embedding_dimension()
|
|
|
|
|
elif setup_mode == CassandraSetupMode.SYNC:
|
|
|
|
|
embedding_dimension = self._get_embedding_dimension()
|
|
|
|
|
|
|
|
|
|
kwargs = {}
|
|
|
|
|
if setup_mode == CassandraSetupMode.ASYNC:
|
|
|
|
|
kwargs["async_setup"] = True
|
|
|
|
|
|
|
|
|
|
self.table = MetadataVectorCassandraTable(
|
|
|
|
|
session=self.session,
|
|
|
|
|
keyspace=self.keyspace,
|
|
|
|
|
table=self.table_name,
|
|
|
|
|
primary_key_type=["TEXT"],
|
|
|
|
|
vector_dimension=self.embedding_dimension,
|
|
|
|
|
vector_dimension=embedding_dimension,
|
|
|
|
|
ttl_seconds=self.ttl_seconds,
|
|
|
|
|
metadata_indexing=("allow", {"_llm_string_hash"}),
|
|
|
|
|
skip_provisioning=skip_provisioning,
|
|
|
|
|
skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _get_embedding_dimension(self) -> int:
    """Return the length of the vector that ``self._get_embedding`` yields
    for a fixed sample sentence."""
    sample_vector = self._get_embedding(text="This is a sample sentence.")
    return len(sample_vector)
|
|
|
|
|
|
|
|
|
|
async def _aget_embedding_dimension(self) -> int:
    """Await ``self._aget_embedding`` on a fixed sample sentence and return
    the length of the resulting vector."""
    sample_vector = await self._aget_embedding(text="This is a sample sentence.")
    return len(sample_vector)
|
|
|
|
|
|
|
|
|
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
|
|
|
|
"""Update cache based on prompt and llm_string."""
|
|
|
|
|
embedding_vector = self._get_embedding(text=prompt)
|
|
|
|
|
llm_string_hash = _hash(llm_string)
|
|
|
|
|
body = _dumps_generations(return_val)
|
|
|
|
@ -1240,7 +1293,7 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
"_llm_string_hash": llm_string_hash,
|
|
|
|
|
}
|
|
|
|
|
row_id = f"{_hash(prompt)}-{llm_string_hash}"
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
self.table.put(
|
|
|
|
|
body_blob=body,
|
|
|
|
|
vector=embedding_vector,
|
|
|
|
@ -1248,14 +1301,39 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
metadata=metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously write ``return_val`` into the semantic cache.

    The prompt is embedded via ``self._aget_embedding``; the row is keyed by
    ``"<hash of prompt>-<hash of llm_string>"`` and carries the raw prompt
    and the llm_string hash as metadata. The serialized generations are
    stored in ``body_blob`` through ``self.table.aput``.
    """
    prompt_vector = await self._aget_embedding(text=prompt)
    hashed_llm_string = _hash(llm_string)
    serialized = _dumps_generations(return_val)
    row_metadata = {
        "_prompt": prompt,
        "_llm_string_hash": hashed_llm_string,
    }
    document_id = f"{_hash(prompt)}-{hashed_llm_string}"

    await self.table.aput(
        body_blob=serialized,
        vector=prompt_vector,
        row_id=document_id,
        metadata=row_metadata,
    )
|
|
|
|
|
|
|
|
|
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Return the cached entry for prompt and llm_string, if any.

    Thin wrapper over ``self.lookup_with_id``: the second element of the
    tuple it returns is handed back, or None when it reports no hit.
    """
    match = self.lookup_with_id(prompt, llm_string)
    return None if match is None else match[1]
|
|
|
|
|
|
|
|
|
|
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Asynchronously return the cached entry for prompt and llm_string.

    Thin wrapper over ``self.alookup_with_id``: the second element of the
    tuple it returns is handed back, or None when it reports no hit.
    """
    match = await self.alookup_with_id(prompt, llm_string)
    return None if match is None else match[1]
|
|
|
|
|
|
|
|
|
|
def lookup_with_id(
|
|
|
|
|
self, prompt: str, llm_string: str
|
|
|
|
|
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
|
|
|
|
@ -1287,6 +1365,37 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
async def alookup_with_id(
    self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """
    Asynchronous lookup by prompt and llm_string.

    The prompt is embedded with ``self._aget_embedding`` and a single-result
    search (``n=1``) is run through ``self.table.ametric_ann_search``,
    restricted to rows whose ``_llm_string_hash`` metadata matches and using
    ``self.distance_metric`` / ``self.score_threshold``.

    Returns ``(row_id, generations)`` for the first hit, or None when there
    is no hit or ``_loads_generations`` yields None for the stored blob.
    """
    query_vector: List[float] = await self._aget_embedding(text=prompt)
    search_results = await self.table.ametric_ann_search(
        vector=query_vector,
        metadata={"_llm_string_hash": _hash(llm_string)},
        n=1,
        metric=self.distance_metric,
        metric_threshold=self.score_threshold,
    )
    candidates = list(search_results)
    if not candidates:
        return None

    best = candidates[0]
    generations = _loads_generations(best["body_blob"])
    if generations is None:
        # guard against malformed cached items
        return None
    return (best["row_id"], generations)
|
|
|
|
|
|
|
|
|
|
def lookup_with_id_through_llm(
|
|
|
|
|
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
|
|
|
|
|
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
|
|
|
|
@ -1296,6 +1405,17 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
)[1]
|
|
|
|
|
return self.lookup_with_id(prompt, llm_string=llm_string)
|
|
|
|
|
|
|
|
|
|
async def alookup_with_id_through_llm(
    self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """Asynchronous lookup-with-id that derives llm_string from an LLM object.

    The LLM's ``dict()`` is combined with the ``stop`` value and passed to
    ``aget_prompts`` with an empty prompt list; element 1 of its result is
    used as the llm_string for ``self.alookup_with_id``.
    """
    llm_params = {**llm.dict(), "stop": stop}
    prompts_result = await aget_prompts(llm_params, [])
    llm_string = prompts_result[1]
    return await self.alookup_with_id(prompt, llm_string=llm_string)
|
|
|
|
|
|
|
|
|
|
def delete_by_document_id(self, document_id: str) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Given this is a "similarity search" cache, an invalidation pattern
|
|
|
|
@ -1304,10 +1424,22 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
"""
|
|
|
|
|
self.table.delete(row_id=document_id)
|
|
|
|
|
|
|
|
|
|
async def adelete_by_document_id(self, document_id: str) -> None:
    """
    Asynchronously delete one cached row by its ID.

    For a similarity-search cache, the sensible invalidation pattern is a
    lookup that yields an ID followed by a delete with that ID; this method
    is the second step and forwards the ID to ``self.table.adelete``.
    """
    target_table = self.table
    await target_table.adelete(row_id=document_id)
|
|
|
|
|
|
|
|
|
|
def clear(self, **kwargs: Any) -> None:
    """Empty the *entire* semantic cache via ``self.table.clear()``.

    Keyword arguments are accepted but not used.
    """
    target_table = self.table
    target_table.clear()
|
|
|
|
|
|
|
|
|
|
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously empty the *entire* semantic cache via
    ``self.table.aclear()``.

    Keyword arguments are accepted but not used.
    """
    target_table = self.table
    await target_table.aclear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FullMd5LLMCache(Base): # type: ignore
|
|
|
|
|
"""SQLite table for full LLM Cache (all generations)."""
|
|
|
|
@ -1412,7 +1544,7 @@ class AstraDBCache(BaseCache):
|
|
|
|
|
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
|
|
|
|
namespace: Optional[str] = None,
|
|
|
|
|
pre_delete_collection: bool = False,
|
|
|
|
|
setup_mode: SetupMode = SetupMode.SYNC,
|
|
|
|
|
setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Cache that uses Astra DB as a backend.
|
|
|
|
@ -1612,7 +1744,7 @@ class AstraDBSemanticCache(BaseCache):
|
|
|
|
|
astra_db_client: Optional[AstraDB] = None,
|
|
|
|
|
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
|
|
|
|
namespace: Optional[str] = None,
|
|
|
|
|
setup_mode: SetupMode = SetupMode.SYNC,
|
|
|
|
|
setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
|
|
|
|
|
pre_delete_collection: bool = False,
|
|
|
|
|
embedding: Embeddings,
|
|
|
|
|
metric: Optional[str] = None,
|
|
|
|
@ -1675,9 +1807,9 @@ class AstraDBSemanticCache(BaseCache):
|
|
|
|
|
self._aget_embedding = _acache_embedding
|
|
|
|
|
|
|
|
|
|
embedding_dimension: Union[int, Awaitable[int], None] = None
|
|
|
|
|
if setup_mode == SetupMode.ASYNC:
|
|
|
|
|
if setup_mode == AstraSetupMode.ASYNC:
|
|
|
|
|
embedding_dimension = self._aget_embedding_dimension()
|
|
|
|
|
elif setup_mode == SetupMode.SYNC:
|
|
|
|
|
elif setup_mode == AstraSetupMode.SYNC:
|
|
|
|
|
embedding_dimension = self._get_embedding_dimension()
|
|
|
|
|
|
|
|
|
|
self.astra_env = _AstraDBCollectionEnvironment(
|
|
|
|
|