#11655 Add SQLAlchemyMd5Cache implementation (#11660)

- **Description:** Add a SQLAlchemyMd5Cache implementation.
  - **Issue:** #11655
  - **Dependencies:** none
  - **Tag maintainer:** @markowanga

---------

Co-authored-by: Marcin Wątroba <marcin.watroba@pwr.edu.pl>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11569/head^2
Marcin Wątroba 1 year ago committed by GitHub
parent 70f7558db2
commit 51a3a86022
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,6 +25,7 @@ import hashlib
import inspect
import json
import logging
import uuid
from datetime import timedelta
from functools import lru_cache
from typing import (
@ -40,7 +41,7 @@ from typing import (
cast,
)
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy import Column, Integer, Row, String, create_engine, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
@ -49,7 +50,6 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.llms.base import LLM, get_prompts
from langchain.load.dump import dumps
from langchain.load.load import loads
@ -1066,3 +1066,87 @@ class CassandraSemanticCache(BaseCache):
def clear(self, **kwargs: Any) -> None:
    """Clear the *whole* semantic cache.

    Delegates to ``self.table.clear()``. ``kwargs`` is accepted but not
    used by this method.
    """
    self.table.clear()
class FullMd5LLMCache(Base):  # type: ignore
    """Table for the full LLM cache (all generations), looked up by prompt MD5.

    ``SQLAlchemyMd5Cache`` writes one row per generation of a cached result
    and creates this table on whatever engine it is given.
    """

    __tablename__ = "full_md5_llm_cache"
    # Primary key; SQLAlchemyMd5Cache.update fills it with a uuid1 string.
    id = Column(String, primary_key=True)
    # Hex MD5 digest of the prompt (see SQLAlchemyMd5Cache.get_md5); indexed.
    prompt_md5 = Column(String, index=True)
    # The llm_string the entry was cached for; indexed.
    llm = Column(String, index=True)
    # Position of the generation within the cached return value; indexed and
    # used to order rows on lookup.
    idx = Column(Integer, index=True)
    # Full prompt text; also compared on lookup, in addition to the digest.
    prompt = Column(String)
    # Serialized generation, as produced by ``dumps``.
    response = Column(String)
class SQLAlchemyMd5Cache(BaseCache):
    """Cache that uses SQLAlchemy as a backend.

    Entries are looked up by the MD5 digest of the prompt (an indexed column)
    together with the LLM string; the full prompt text is compared as well,
    so two prompts sharing a digest are never confused with each other.
    """

    def __init__(
        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
    ):
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine the cache reads from and writes to.
            cache_schema: Declarative model class used to store the entries.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations in their original order, or ``None`` when
            nothing is stored for this prompt / LLM pair.
        """
        rows = self._search_rows(prompt, llm_string)
        if rows:
            return [loads(row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string.

        Any entries previously stored for the same prompt / LLM pair are
        replaced by one row per generation in ``return_val``.
        """
        prompt_md5 = self.get_md5(prompt)
        items = [
            self.cache_schema(
                id=str(uuid.uuid1()),
                prompt=prompt,
                prompt_md5=prompt_md5,
                llm=llm_string,
                response=dumps(gen),
                idx=i,
            )
            for i, gen in enumerate(return_val)
        ]
        # Delete and insert inside ONE transaction so a failure while writing
        # the new rows cannot leave the entry deleted but not replaced.
        with Session(self.engine) as session, session.begin():
            self._delete_matching(session, prompt, llm_string)
            for item in items:
                session.merge(item)

    def _delete_previous(self, prompt: str, llm_string: str) -> None:
        """Delete every row stored for this prompt / LLM pair."""
        with Session(self.engine) as session, session.begin():
            self._delete_matching(session, prompt, llm_string)

    def _delete_matching(self, session: Session, prompt: str, llm_string: str) -> None:
        """Bulk-delete the rows for this prompt / LLM pair within ``session``.

        Uses a query-level DELETE: ``Session.delete`` only accepts mapped
        instances, so it cannot be fed the ``Row`` objects returned by a
        column ``select``.
        """
        session.query(self.cache_schema).filter(
            self.cache_schema.prompt_md5 == self.get_md5(prompt),  # type: ignore
            self.cache_schema.llm == llm_string,
            self.cache_schema.prompt == prompt,
        ).delete(synchronize_session=False)

    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
        """Return the ``response`` rows for this prompt / LLM pair, by ``idx``."""
        prompt_md5 = self.get_md5(prompt)
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            return list(session.execute(stmt).fetchall())

    def clear(self, **kwargs: Any) -> None:
        """Clear cache.

        Removes every row of the cache table and commits the deletion.
        ``kwargs`` is accepted but not used.
        """
        with Session(self.engine) as session, session.begin():
            session.query(self.cache_schema).delete()

    @staticmethod
    def get_md5(input_string: str) -> str:
        """Return the hex MD5 digest of ``input_string``.

        The digest is used only as a lookup key, not for any security purpose.
        """
        return hashlib.md5(input_string.encode()).hexdigest()

Loading…
Cancel
Save