diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py
index bbcc87bd39..be96bdf441 100644
--- a/libs/langchain/langchain/cache.py
+++ b/libs/langchain/langchain/cache.py
@@ -25,6 +25,7 @@ import hashlib
 import inspect
 import json
 import logging
+import uuid
 from datetime import timedelta
 from functools import lru_cache
 from typing import (
@@ -40,7 +41,8 @@ from typing import (
     cast,
 )
 
-from sqlalchemy import Column, Integer, String, create_engine, select
+from sqlalchemy import Column, Integer, String, create_engine, delete, select
+from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session
 
@@ -49,7 +51,6 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
-
 from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
@@ -1066,3 +1067,85 @@ class CassandraSemanticCache(BaseCache):
     def clear(self, **kwargs: Any) -> None:
         """Clear the *whole* semantic cache."""
         self.table.clear()
+
+
+class FullMd5LLMCache(Base):  # type: ignore
+    """SQLite table for full LLM Cache (all generations)."""
+
+    __tablename__ = "full_md5_llm_cache"
+    id = Column(String, primary_key=True)
+    prompt_md5 = Column(String, index=True)
+    llm = Column(String, index=True)
+    idx = Column(Integer, index=True)
+    prompt = Column(String)
+    response = Column(String)
+
+
+class SQLAlchemyMd5Cache(BaseCache):
+    """Cache that uses SQLAlchemy as a backend."""
+
+    def __init__(
+        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
+    ):
+        """Initialize by creating all tables."""
+        self.engine = engine
+        self.cache_schema = cache_schema
+        self.cache_schema.metadata.create_all(self.engine)
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        rows = self._search_rows(prompt, llm_string)
+        if rows:
+            return [loads(row[0]) for row in rows]
+        return None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update based on prompt and llm_string."""
+        self._delete_previous(prompt, llm_string)
+        prompt_md5 = self.get_md5(prompt)
+        items = [
+            self.cache_schema(
+                id=str(uuid.uuid1()),
+                prompt=prompt,
+                prompt_md5=prompt_md5,
+                llm=llm_string,
+                response=dumps(gen),
+                idx=i,
+            )
+            for i, gen in enumerate(return_val)
+        ]
+        with Session(self.engine) as session, session.begin():
+            for item in items:
+                session.merge(item)
+
+    def _delete_previous(self, prompt: str, llm_string: str) -> None:
+        """Delete all cached rows for this prompt/llm pair."""
+        stmt = (
+            delete(self.cache_schema)
+            .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))  # type: ignore
+            .where(self.cache_schema.llm == llm_string)
+            .where(self.cache_schema.prompt == prompt)
+        )
+        with Session(self.engine) as session, session.begin():
+            session.execute(stmt)
+
+    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
+        prompt_md5 = self.get_md5(prompt)
+        stmt = (
+            select(self.cache_schema.response)
+            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
+            .where(self.cache_schema.llm == llm_string)
+            .where(self.cache_schema.prompt == prompt)
+            .order_by(self.cache_schema.idx)
+        )
+        with Session(self.engine) as session:
+            return session.execute(stmt).fetchall()
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        with Session(self.engine) as session, session.begin():
+            session.execute(delete(self.cache_schema))
+
+    @staticmethod
+    def get_md5(input_string: str) -> str:
+        return hashlib.md5(input_string.encode()).hexdigest()
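
For reviewers, here is a minimal usage sketch (not part of the diff). It assumes the module-level `langchain.llm_cache` hook that LLM calls consult in this version of LangChain, and a hypothetical local SQLite file `cache.db`; any SQLAlchemy engine should work. Note the design choice in the table: lookups filter on the short, indexed `prompt_md5` column first, and the extra `.where(prompt == prompt)` clause guards against MD5 collisions.

```python
# Minimal usage sketch, assuming the `langchain.llm_cache` global hook and a
# hypothetical SQLite file `cache.db`; any SQLAlchemy-supported DB should work.
import langchain
from langchain.cache import SQLAlchemyMd5Cache
from langchain.llms import OpenAI
from sqlalchemy import create_engine

langchain.llm_cache = SQLAlchemyMd5Cache(engine=create_engine("sqlite:///cache.db"))

llm = OpenAI()  # requires OPENAI_API_KEY in the environment
llm("Tell me a joke")  # first call hits the provider and populates full_md5_llm_cache
llm("Tell me a joke")  # identical prompt + llm_string is now served from the cache
```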