|
|
|
@ -25,6 +25,7 @@ import hashlib
|
|
|
|
|
import inspect
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import uuid
|
|
|
|
|
from datetime import timedelta
|
|
|
|
|
from functools import lru_cache
|
|
|
|
|
from typing import (
|
|
|
|
@ -40,7 +41,7 @@ from typing import (
|
|
|
|
|
cast,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import Column, Integer, String, create_engine, select
|
|
|
|
|
from sqlalchemy import Column, Integer, Row, String, create_engine, select
|
|
|
|
|
from sqlalchemy.engine.base import Engine
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
@ -49,7 +50,6 @@ try:
|
|
|
|
|
except ImportError:
|
|
|
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.llms.base import LLM, get_prompts
|
|
|
|
|
from langchain.load.dump import dumps
|
|
|
|
|
from langchain.load.load import loads
|
|
|
|
@ -1066,3 +1066,87 @@ class CassandraSemanticCache(BaseCache):
|
|
|
|
|
def clear(self, **kwargs: Any) -> None:
    """Remove every entry from the semantic cache.

    Delegates to ``self.table.clear()``. Keyword arguments are accepted
    but not used by this method.
    """
    backing_table = self.table
    backing_table.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FullMd5LLMCache(Base):  # type: ignore
    """Table for the full LLM cache (all generations), looked up by prompt MD5.

    One row is stored per generation. Nothing in this model is specific to a
    particular database; the engine is chosen by whoever creates the tables.
    """

    __tablename__ = "full_md5_llm_cache"
    # Primary key; SQLAlchemyMd5Cache.update fills it with str(uuid.uuid1()).
    id = Column(String, primary_key=True)
    # MD5 hex digest of the prompt; indexed so lookups need not scan `prompt`.
    prompt_md5 = Column(String, index=True)
    # The llm_string the generations were produced with.
    llm = Column(String, index=True)
    # Position of this generation within the cached list (used for ordering).
    idx = Column(Integer, index=True)
    # Full prompt text, compared in addition to the digest on lookup.
    prompt = Column(String)
    # Serialized generation (written with `dumps`, read back with `loads`).
    response = Column(String)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQLAlchemyMd5Cache(BaseCache):
    """Cache that uses SQLAlchemy as a backend.

    Entries are located by the MD5 hex digest of the prompt together with the
    LLM string. The full prompt text is compared as well, so a digest
    collision cannot return another prompt's generations. Each generation is
    stored as its own row and read back ordered by ``idx``.
    """

    def __init__(
        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
    ):
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine used for every cache operation.
            cache_schema: Declarative model class backing the cache table.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The deserialized generations in their stored order, or ``None``
            when nothing is cached for this (prompt, llm_string) pair.
        """
        rows = self._search_rows(prompt, llm_string)
        if rows:
            return [loads(row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string.

        Replaces any rows previously stored for the same (prompt, llm_string)
        pair with one row per generation in ``return_val``.
        """
        prompt_md5 = self.get_md5(prompt)
        items = [
            self.cache_schema(
                id=str(uuid.uuid1()),
                prompt=prompt,
                prompt_md5=prompt_md5,
                llm=llm_string,
                response=dumps(gen),
                idx=i,
            )
            for i, gen in enumerate(return_val)
        ]
        # Delete and insert share one transaction, so a failed insert rolls
        # the delete back instead of leaving the entry missing.
        with Session(self.engine) as session, session.begin():
            session.execute(self._delete_statement(prompt, llm_string))
            # The ids are freshly generated, so a plain insert is enough;
            # merge() would issue an extra SELECT per row first.
            session.add_all(items)

    def _delete_statement(self, prompt: str, llm_string: str) -> Any:
        """Build the DELETE statement matching one (prompt, llm_string) pair."""
        return (
            self.cache_schema.__table__.delete()  # type: ignore
            .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
        )

    def _delete_previous(self, prompt: str, llm_string: str) -> None:
        """Delete every row cached for this exact prompt and llm_string.

        A single DELETE statement is issued. Session.delete() only accepts
        mapped instances, so it cannot be fed the Row objects that a column
        SELECT returns.
        """
        with Session(self.engine) as session, session.begin():
            session.execute(self._delete_statement(prompt, llm_string))

    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
        """Return the ``response`` rows cached for the pair, ordered by ``idx``."""
        prompt_md5 = self.get_md5(prompt)
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            return list(session.execute(stmt).fetchall())

    def clear(self, **kwargs: Any) -> None:
        """Clear cache.

        Deletes every row of the cache table. Keyword arguments are accepted
        but not used.
        """
        # session.begin() commits on exit; without it the DELETE would be
        # rolled back when the session closes.
        with Session(self.engine) as session, session.begin():
            session.execute(self.cache_schema.__table__.delete())  # type: ignore

    @staticmethod
    def get_md5(input_string: str) -> str:
        """Return the MD5 hex digest of ``input_string`` (default str encoding)."""
        return hashlib.md5(input_string.encode()).hexdigest()
|
|
|
|
|