forked from Archives/langchain
Harrison/sqlalchemy cache store (#536)
Co-authored-by: Jason Gill <jasongill@gmail.com>
This commit is contained in:
parent
870cccb877
commit
73f7ebd9d1
@ -276,6 +276,52 @@
|
|||||||
"# langchain.llm_cache = SQLAlchemyCache(engine)"
|
"# langchain.llm_cache = SQLAlchemyCache(engine)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Custom SQLAlchemy Schemas"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# You can define your own declarative SQLAlchemyCache child class to customize the schema used for caching. For example, to support high-speed fulltext prompt indexing with Postgres, use:\n",
|
||||||
|
"\n",
|
||||||
|
"from sqlalchemy import Column, Integer, String, Computed, Index, Sequence\n",
|
||||||
|
"from sqlalchemy import create_engine\n",
|
||||||
|
"from sqlalchemy.ext.declarative import declarative_base\n",
|
||||||
|
"from sqlalchemy_utils import TSVectorType\n",
|
||||||
|
"from langchain.cache import SQLAlchemyCache\n",
|
||||||
|
"\n",
|
||||||
|
"Base = declarative_base()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class FulltextLLMCache(Base): # type: ignore\n",
|
||||||
|
" \"\"\"Postgres table for fulltext-indexed LLM Cache\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
" __tablename__ = \"llm_cache_fulltext\"\n",
|
||||||
|
" id = Column(Integer, Sequence('cache_id'), primary_key=True)\n",
|
||||||
|
" prompt = Column(String, nullable=False)\n",
|
||||||
|
" llm = Column(String, nullable=False)\n",
|
||||||
|
" idx = Column(Integer)\n",
|
||||||
|
" response = Column(String)\n",
|
||||||
|
" prompt_tsv = Column(TSVectorType(), Computed(\"to_tsvector('english', llm || ' ' || prompt)\", persisted=True))\n",
|
||||||
|
" __table_args__ = (\n",
|
||||||
|
" Index(\"idx_fulltext_prompt_tsv\", prompt_tsv, postgresql_using=\"gin\"),\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"engine = create_engine(\"postgresql://postgres:postgres@localhost:5432/postgres\")\n",
|
||||||
|
"langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "0c69d84d",
|
"id": "0c69d84d",
|
||||||
|
@ -56,18 +56,19 @@ class FullLLMCache(Base): # type: ignore
|
|||||||
class SQLAlchemyCache(BaseCache):
|
class SQLAlchemyCache(BaseCache):
|
||||||
"""Cache that uses SQAlchemy as a backend."""
|
"""Cache that uses SQAlchemy as a backend."""
|
||||||
|
|
||||||
def __init__(self, engine: Engine):
|
def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
|
||||||
"""Initialize by creating all tables."""
|
"""Initialize by creating all tables."""
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self.cache_schema = cache_schema
|
||||||
Base.metadata.create_all(self.engine)
|
Base.metadata.create_all(self.engine)
|
||||||
|
|
||||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
"""Look up based on prompt and llm_string."""
|
"""Look up based on prompt and llm_string."""
|
||||||
stmt = (
|
stmt = (
|
||||||
select(FullLLMCache.response)
|
select(self.cache_schema.response)
|
||||||
.where(FullLLMCache.prompt == prompt)
|
.where(self.cache_schema.prompt == prompt)
|
||||||
.where(FullLLMCache.llm == llm_string)
|
.where(self.cache_schema.llm == llm_string)
|
||||||
.order_by(FullLLMCache.idx)
|
.order_by(self.cache_schema.idx)
|
||||||
)
|
)
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
generations = []
|
generations = []
|
||||||
@ -80,7 +81,7 @@ class SQLAlchemyCache(BaseCache):
|
|||||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
"""Look up based on prompt and llm_string."""
|
"""Look up based on prompt and llm_string."""
|
||||||
for i, generation in enumerate(return_val):
|
for i, generation in enumerate(return_val):
|
||||||
item = FullLLMCache(
|
item = self.cache_schema(
|
||||||
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
||||||
)
|
)
|
||||||
with Session(self.engine) as session, session.begin():
|
with Session(self.engine) as session, session.begin():
|
||||||
|
Loading…
Reference in New Issue
Block a user