mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
community[patch]: Fix SQLAlchemyMd5Cache race condition (#16279)
If the SQLAlchemyMd5Cache is shared among multiple processes, it is possible to encounter a race condition during the cache update. Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
70c296ae96
commit
d07db457fc
@ -37,13 +37,14 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||
from sqlalchemy import Column, Integer, String, create_engine, delete, select
|
||||
from sqlalchemy.engine import Row
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
# Method of SQLAlchemyMd5Cache (class header not visible in this chunk).
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
    """Update the cache based on ``prompt`` and ``llm_string``.

    The delete of any previous rows and the insert of the new rows are
    executed inside a single session/transaction.  The earlier version
    ran ``_delete_previous`` in its own transaction before opening a
    second one for the insert, leaving a window in which another process
    sharing the cache could interleave — the race condition fixed here.

    Args:
        prompt: The prompt text used as (part of) the cache key.
        llm_string: Serialized LLM configuration, the other key half.
        return_val: Sequence of generations to store, one row per item.
    """
    with Session(self.engine) as session, session.begin():
        # Delete stale rows inside the SAME transaction as the inserts
        # below so the two steps are atomic with respect to other writers.
        self._delete_previous(session, prompt, llm_string)
        prompt_md5 = self.get_md5(prompt)
        items = [
            self.cache_schema(
                id=str(uuid.uuid1()),
                prompt=prompt,
                prompt_md5=prompt_md5,
                llm=llm_string,
                response=dumps(gen),
                idx=i,  # preserves generation order on read-back
            )
            for i, gen in enumerate(return_val)
        ]
        for item in items:
            # merge() upserts by primary key, so rows written concurrently
            # by another process do not raise a uniqueness error.
            session.merge(item)
def _delete_previous(self, prompt: str, llm_string: str) -> None:
|
||||
def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
|
||||
stmt = (
|
||||
select(self.cache_schema.response)
|
||||
delete(self.cache_schema)
|
||||
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore
|
||||
.where(self.cache_schema.llm == llm_string)
|
||||
.where(self.cache_schema.prompt == prompt)
|
||||
.order_by(self.cache_schema.idx)
|
||||
)
|
||||
with Session(self.engine) as session, session.begin():
|
||||
rows = session.execute(stmt).fetchall()
|
||||
for item in rows:
|
||||
session.delete(item)
|
||||
session.execute(stmt)
|
||||
|
||||
def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
|
||||
def _search_rows(self, prompt: str, llm_string: str) -> Sequence[Row]:
|
||||
prompt_pd5 = self.get_md5(prompt)
|
||||
stmt = (
|
||||
select(self.cache_schema.response)
|
||||
|
Loading…
Reference in New Issue
Block a user