community[major]: breaking change in some APIs to force users to opt-in for pickling (#18696)

This is a PR that adds a dangerous load parameter to force users to opt in to use pickle.

This is a PR that's meant to raise user awareness that the pickling module is involved.
pull/18700/head
Eugene Yurtsev 3 months ago committed by GitHub
parent 0e52961562
commit 4c25b49229
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -250,7 +250,6 @@ def _pickle_fn_to_hex_string(fn: Callable) -> str:
class Databricks(LLM):
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
It supports two endpoint types:
@ -374,6 +373,15 @@ class Databricks(LLM):
If not provided, the task is automatically inferred from the endpoint.
"""
allow_dangerous_deserialization: bool = False
"""Whether to allow dangerous deserialization of the data which
involves loading data using pickle.
If the data has been modified by a malicious actor, it can deliver a
malicious payload that results in execution of arbitrary code on the target
machine.
"""
_client: _DatabricksClientBase = PrivateAttr()
class Config:
@ -435,6 +443,16 @@ class Databricks(LLM):
return v
def __init__(self, **data: Any):
if not data.get("allow_dangerous_deserialization"):
raise ValueError(
"This code relies on the pickle module. "
"You will need to set allow_dangerous_deserialization=True "
"if you want to opt-in to allow deserialization of data using pickle."
"Data can be compromised by a malicious actor if "
"not handled properly to include "
"a malicious payload that when deserialized with "
"pickle can execute arbitrary code on your machine."
)
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
data["transform_input_fn"]

@ -137,6 +137,11 @@ class SelfHostedPipeline(LLM):
model_reqs: List[str] = ["./", "torch"]
"""Requirements to install on hardware to inference the model."""
allow_dangerous_deserialization: bool = False
"""Allow deserialization using pickle which can be dangerous if
loading compromised data.
"""
class Config:
"""Configuration for this pydantic object."""
@ -149,6 +154,16 @@ class SelfHostedPipeline(LLM):
and run on the server, i.e. in a module and not a REPL or closure.
Then, initialize the remote inference function.
"""
if not kwargs.get("allow_dangerous_deserialization"):
raise ValueError(
"SelfHostedPipeline relies on the pickle module. "
"You will need to set allow_dangerous_deserialization=True "
"if you want to opt-in to allow deserialization of data using pickle."
"Data can be compromised by a malicious actor if "
"not handled properly to include "
"a malicious payload that when deserialized with "
"pickle can execute arbitrary code. "
)
super().__init__(**kwargs)
try:
import runhouse as rh

@ -429,6 +429,8 @@ class Annoy(VectorStore):
cls,
folder_path: str,
embeddings: Embeddings,
*,
allow_dangerous_deserialization: bool = False,
) -> Annoy:
"""Load Annoy index, docstore, and index_to_docstore_id to disk.
@ -436,7 +438,25 @@ class Annoy(VectorStore):
folder_path: folder path to load index, docstore,
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries.
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
# load index separately since it is not picklable
annoy = dependable_annoy_import()

@ -1093,6 +1093,8 @@ class FAISS(VectorStore):
folder_path: str,
embeddings: Embeddings,
index_name: str = "index",
*,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
) -> FAISS:
"""Load FAISS index, docstore, and index_to_docstore_id from disk.
@ -1102,8 +1104,26 @@ class FAISS(VectorStore):
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries
index_name: for saving with a specific index file name
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
asynchronous: whether to use async version or not
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
# load index separately since it is not picklable
faiss = dependable_faiss_import()

@ -460,6 +460,8 @@ class ScaNN(VectorStore):
folder_path: str,
embedding: Embeddings,
index_name: str = "index",
*,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
) -> ScaNN:
"""Load ScaNN index, docstore, and index_to_docstore_id from disk.
@ -469,7 +471,25 @@ class ScaNN(VectorStore):
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries
index_name: for saving with a specific index file name
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
scann_path = path / "{index_name}.scann".format(index_name=index_name)
scann_path.mkdir(exist_ok=True, parents=True)

@ -87,9 +87,28 @@ class TileDB(VectorStore):
docs_array_uri: str = "",
config: Optional[Mapping[str, Any]] = None,
timestamp: Any = None,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
):
"""Initialize with necessary components."""
"""Initialize with necessary components.
Args:
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading data using pickle.
data can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"TileDB relies on pickle for serialization and deserialization. "
"This can be dangerous if the data is intercepted and/or modified "
"by malicious actors prior to being de-serialized. "
"If you are sure that the data is safe from modification, you can "
" set allow_dangerous_deserialization=True to proceed. "
"Loading of compromised data using pickle can result in execution of "
"arbitrary code on your machine."
)
self.embedding = embedding
self.embedding_function = embedding.embed_query
self.index_uri = index_uri

@ -116,7 +116,9 @@ def test_annoy_local_save_load() -> None:
temp_dir = tempfile.TemporaryDirectory()
docsearch.save_local(temp_dir.name)
loaded_docsearch = Annoy.load_local(temp_dir.name, FakeEmbeddings())
loaded_docsearch = Annoy.load_local(
temp_dir.name, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id
assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__

@ -252,7 +252,9 @@ def test_scann_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
new_docsearch = ScaNN.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None

@ -44,8 +44,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
llm = Databricks(
endpoint_name="databricks-mixtral-8x7b-instruct",
endpoint_name="some_end_point_name", # Value should not matter for this test
transform_input_fn=transform_input,
allow_dangerous_deserialization=True,
)
params = llm._default_params
pickled_string = cloudpickle.dumps(transform_input).hex()

@ -608,7 +608,9 @@ def test_faiss_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
new_docsearch = FAISS.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None
@ -620,7 +622,9 @@ async def test_faiss_async_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
new_docsearch = FAISS.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None

Loading…
Cancel
Save