Harrison/sqlalchemy cache store (#536)

Co-authored-by: Jason Gill <jasongill@gmail.com>
Harrison Chase 2023-01-04 18:38:15 -08:00 committed by GitHub
parent 870cccb877
commit 73f7ebd9d1
2 changed files with 53 additions and 6 deletions

@@ -276,6 +276,52 @@
 "# langchain.llm_cache = SQLAlchemyCache(engine)"
 ]
 },
+{
+"cell_type": "markdown",
+"source": [
+"### Custom SQLAlchemy Schemas"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"outputs": [],
+"source": [
+"# You can define your own declarative SQLAlchemyCache child class to customize the schema used for caching. For example, to support high-speed fulltext prompt indexing with Postgres, use:\n",
+"\n",
+"from sqlalchemy import Column, Integer, String, Computed, Index, Sequence\n",
+"from sqlalchemy import create_engine\n",
+"from sqlalchemy.ext.declarative import declarative_base\n",
+"from sqlalchemy_utils import TSVectorType\n",
+"from langchain.cache import SQLAlchemyCache\n",
+"\n",
+"Base = declarative_base()\n",
+"\n",
+"\n",
+"class FulltextLLMCache(Base):  # type: ignore\n",
+"    \"\"\"Postgres table for fulltext-indexed LLM Cache\"\"\"\n",
+"\n",
+"    __tablename__ = \"llm_cache_fulltext\"\n",
+"    id = Column(Integer, Sequence('cache_id'), primary_key=True)\n",
+"    prompt = Column(String, nullable=False)\n",
+"    llm = Column(String, nullable=False)\n",
+"    idx = Column(Integer)\n",
+"    response = Column(String)\n",
+"    prompt_tsv = Column(TSVectorType(), Computed(\"to_tsvector('english', llm || ' ' || prompt)\", persisted=True))\n",
+"    __table_args__ = (\n",
+"        Index(\"idx_fulltext_prompt_tsv\", prompt_tsv, postgresql_using=\"gin\"),\n",
+"    )\n",
+"\n",
+"engine = create_engine(\"postgresql://postgres:postgres@localhost:5432/postgres\")\n",
+"langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)"
+],
+"metadata": {
+"collapsed": false
+}
+},
 {
 "cell_type": "markdown",
 "id": "0c69d84d",


@@ -56,18 +56,19 @@ class FullLLMCache(Base):  # type: ignore

 class SQLAlchemyCache(BaseCache):
     """Cache that uses SQLAlchemy as a backend."""

-    def __init__(self, engine: Engine):
+    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
         """Initialize by creating all tables."""
         self.engine = engine
+        self.cache_schema = cache_schema
         Base.metadata.create_all(self.engine)

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
         stmt = (
-            select(FullLLMCache.response)
-            .where(FullLLMCache.prompt == prompt)
-            .where(FullLLMCache.llm == llm_string)
-            .order_by(FullLLMCache.idx)
+            select(self.cache_schema.response)
+            .where(self.cache_schema.prompt == prompt)
+            .where(self.cache_schema.llm == llm_string)
+            .order_by(self.cache_schema.idx)
         )
         with Session(self.engine) as session:
             generations = []
@@ -80,7 +81,7 @@ class SQLAlchemyCache(BaseCache):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update based on prompt and llm_string."""
         for i, generation in enumerate(return_val):
-            item = FullLLMCache(
+            item = self.cache_schema(
                 prompt=prompt, llm=llm_string, response=generation.text, idx=i
             )
             with Session(self.engine) as session, session.begin():
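
The net effect: SQLAlchemyCache keeps its old one-argument behavior (cache_schema defaults to FullLLMCache), while any declarative model exposing prompt, llm, idx, and response columns can be swapped in. A minimal sketch of exercising the class directly, assuming an in-memory SQLite engine and that Generation is importable from langchain.schema in this version of the library (both are illustrative assumptions, not part of the diff):

from sqlalchemy import create_engine

from langchain.cache import SQLAlchemyCache
from langchain.schema import Generation  # assumed import location

engine = create_engine("sqlite://")  # throwaway in-memory database
cache = SQLAlchemyCache(engine)  # no second argument: falls back to FullLLMCache

# update() writes one row per generation, keyed by (prompt, llm_string, idx);
# lookup() reads the rows back in idx order as a list of Generation objects.
cache.update("my prompt", "my llm", [Generation(text="hello")])
cached = cache.lookup("my prompt", "my llm")
assert cached is not None and cached[0].text == "hello"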