community[patch]: Fix SQLAlchemyMd5Cache race condition (#16279)

If the SQLAlchemyMd5Cache is shared among multiple processes, it is
possible to encounter a race condition during the cache update.
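For context, a minimal standalone sketch of the pattern the patch adopts: the delete of stale
rows and the insert of the fresh ones run inside a single transaction, so two processes updating
the same prompt cannot interleave between the two steps. The model, table name and helper below
(CacheRow, cache_update) are illustrative stand-ins, not the library's own classes.

import uuid
from typing import List

from sqlalchemy import Column, Integer, String, create_engine, delete
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class CacheRow(Base):  # illustrative stand-in for the library's cache schema
    __tablename__ = "llm_cache_sketch"
    id = Column(String, primary_key=True)
    prompt_md5 = Column(String, index=True)
    prompt = Column(String)
    llm = Column(String)
    response = Column(String)
    idx = Column(Integer)


engine = create_engine("sqlite:///llm_cache_sketch.db")
Base.metadata.create_all(engine)


def cache_update(prompt_md5: str, prompt: str, llm_string: str, responses: List[str]) -> None:
    # A single transaction covers both steps; with delete and insert in separate
    # transactions, another process could write between them and race.
    with Session(engine) as session, session.begin():
        session.execute(
            delete(CacheRow)
            .where(CacheRow.prompt_md5 == prompt_md5)
            .where(CacheRow.llm == llm_string)
            .where(CacheRow.prompt == prompt)
        )
        for i, resp in enumerate(responses):
            session.merge(
                CacheRow(
                    id=str(uuid.uuid1()),
                    prompt_md5=prompt_md5,
                    prompt=prompt,
                    llm=llm_string,
                    response=resp,
                    idx=i,
                )
            )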

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Philippe PRADOS 2024-02-14 20:45:28 +01:00 committed by GitHub
parent 70c296ae96
commit d07db457fc

@@ -37,13 +37,14 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
     cast,
 )

-from sqlalchemy import Column, Integer, String, create_engine, select
+from sqlalchemy import Column, Integer, String, create_engine, delete, select
 from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session
@@ -1308,37 +1309,33 @@ class SQLAlchemyMd5Cache(BaseCache):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update based on prompt and llm_string."""
-        self._delete_previous(prompt, llm_string)
-        prompt_md5 = self.get_md5(prompt)
-        items = [
-            self.cache_schema(
-                id=str(uuid.uuid1()),
-                prompt=prompt,
-                prompt_md5=prompt_md5,
-                llm=llm_string,
-                response=dumps(gen),
-                idx=i,
-            )
-            for i, gen in enumerate(return_val)
-        ]
         with Session(self.engine) as session, session.begin():
+            self._delete_previous(session, prompt, llm_string)
+            prompt_md5 = self.get_md5(prompt)
+            items = [
+                self.cache_schema(
+                    id=str(uuid.uuid1()),
+                    prompt=prompt,
+                    prompt_md5=prompt_md5,
+                    llm=llm_string,
+                    response=dumps(gen),
+                    idx=i,
+                )
+                for i, gen in enumerate(return_val)
+            ]
             for item in items:
                 session.merge(item)

-    def _delete_previous(self, prompt: str, llm_string: str) -> None:
+    def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
         stmt = (
-            select(self.cache_schema.response)
+            delete(self.cache_schema)
             .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))  # type: ignore
             .where(self.cache_schema.llm == llm_string)
             .where(self.cache_schema.prompt == prompt)
-            .order_by(self.cache_schema.idx)
         )
-        with Session(self.engine) as session, session.begin():
-            rows = session.execute(stmt).fetchall()
-            for item in rows:
-                session.delete(item)
+        session.execute(stmt)

-    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
+    def _search_rows(self, prompt: str, llm_string: str) -> Sequence[Row]:
         prompt_pd5 = self.get_md5(prompt)
         stmt = (
             select(self.cache_schema.response)
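
Usage sketch (not part of the commit): each worker process points its own Engine at the same
shared database, so concurrent updates for the same prompt now contend only inside the database
transaction. The constructor call below assumes SQLAlchemyMd5Cache is built from a SQLAlchemy
Engine, as SQLAlchemyCache is; check the signature in langchain_community.cache before relying
on it.

from sqlalchemy import create_engine

from langchain.globals import set_llm_cache
from langchain_community.cache import SQLAlchemyMd5Cache

# Every process creates its own Engine against the shared cache database.
engine = create_engine("postgresql+psycopg2://user:password@db-host/llm_cache")
set_llm_cache(SQLAlchemyMd5Cache(engine))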