From 73f7ebd9d11bef702f7743a4acaa25940a9393ca Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Wed, 4 Jan 2023 18:38:15 -0800
Subject: [PATCH] Harrison/sqlalchemy cache store (#536)

Co-authored-by: Jason Gill
---
 docs/modules/llms/examples/llm_caching.ipynb | 46 ++++++++++++++++++++
 langchain/cache.py                           | 15 ++++---
 2 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/docs/modules/llms/examples/llm_caching.ipynb b/docs/modules/llms/examples/llm_caching.ipynb
index 3538dd5a..4060689b 100644
--- a/docs/modules/llms/examples/llm_caching.ipynb
+++ b/docs/modules/llms/examples/llm_caching.ipynb
@@ -276,6 +276,52 @@
     "# langchain.llm_cache = SQLAlchemyCache(engine)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Custom SQLAlchemy Schemas"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# You can define your own declarative SQLAlchemy model and pass it to SQLAlchemyCache to customize the schema used for caching. The model must define prompt, llm, idx, and response columns, since the cache queries them. For example, to support high-speed fulltext prompt indexing with Postgres, use:\n",
+    "\n",
+    "from sqlalchemy import Column, Integer, String, Computed, Index, Sequence\n",
+    "from sqlalchemy import create_engine\n",
+    "from sqlalchemy.ext.declarative import declarative_base\n",
+    "from sqlalchemy_utils import TSVectorType\n",
+    "from langchain.cache import SQLAlchemyCache\n",
+    "\n",
+    "Base = declarative_base()\n",
+    "\n",
+    "\n",
+    "class FulltextLLMCache(Base):  # type: ignore\n",
+    "    \"\"\"Postgres table for fulltext-indexed LLM Cache\"\"\"\n",
+    "\n",
+    "    __tablename__ = \"llm_cache_fulltext\"\n",
+    "    id = Column(Integer, Sequence('cache_id'), primary_key=True)\n",
+    "    prompt = Column(String, nullable=False)\n",
+    "    llm = Column(String, nullable=False)\n",
+    "    idx = Column(Integer)\n",
+    "    response = Column(String)\n",
+    "    prompt_tsv = Column(TSVectorType(), Computed(\"to_tsvector('english', llm || ' ' || prompt)\", persisted=True))\n",
+    "    __table_args__ = (\n",
+    "        Index(\"idx_fulltext_prompt_tsv\", prompt_tsv, postgresql_using=\"gin\"),\n",
+    "    )\n",
+    "\n",
+    "engine = create_engine(\"postgresql://postgres:postgres@localhost:5432/postgres\")\n",
+    "langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
   {
    "cell_type": "markdown",
    "id": "0c69d84d",
diff --git a/langchain/cache.py b/langchain/cache.py
index 9198b8e2..869e1556 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -56,18 +56,19 @@ class FullLLMCache(Base):  # type: ignore
 class SQLAlchemyCache(BaseCache):
     """Cache that uses SQLAlchemy as a backend."""
 
-    def __init__(self, engine: Engine):
+    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
         """Initialize by creating all tables."""
         self.engine = engine
-        Base.metadata.create_all(self.engine)
+        self.cache_schema = cache_schema
+        self.cache_schema.metadata.create_all(self.engine)
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
         stmt = (
-            select(FullLLMCache.response)
-            .where(FullLLMCache.prompt == prompt)
-            .where(FullLLMCache.llm == llm_string)
-            .order_by(FullLLMCache.idx)
+            select(self.cache_schema.response)
+            .where(self.cache_schema.prompt == prompt)
+            .where(self.cache_schema.llm == llm_string)
+            .order_by(self.cache_schema.idx)
         )
         with Session(self.engine) as session:
             generations = []
@@ -80,7 +81,7 @@ class SQLAlchemyCache(BaseCache):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
         for i, generation in enumerate(return_val):
-            item = FullLLMCache(
+            item = self.cache_schema(
                 prompt=prompt, llm=llm_string, response=generation.text, idx=i
             )
             with Session(self.engine) as session, session.begin():
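
For reference, the new cache_schema hook is not tied to Postgres or sqlalchemy_utils. The sketch below is illustrative rather than part of the patch: it uses an in-memory SQLite engine and a hypothetical MinimalLLMCache model. Any declarative model that defines prompt, llm, idx, and response columns should work, since lookup() and update() only reference those attributes.

# Illustrative sketch (not part of the patch): a bare-bones custom schema.
# MinimalLLMCache is a hypothetical name; the four columns below are the
# ones SQLAlchemyCache actually queries.
import langchain
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base

from langchain.cache import SQLAlchemyCache

Base = declarative_base()


class MinimalLLMCache(Base):  # type: ignore
    """Minimal cache table: just the columns the cache reads and writes."""

    __tablename__ = "llm_cache_minimal"
    id = Column(Integer, primary_key=True)
    prompt = Column(String, nullable=False)
    llm = Column(String, nullable=False)
    idx = Column(Integer)
    response = Column(String)


# In-memory SQLite keeps the example self-contained; any Engine works.
engine = create_engine("sqlite://")
langchain.llm_cache = SQLAlchemyCache(engine, MinimalLLMCache)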