diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py index 3b0bca7edb..715530f7a7 100644 --- a/libs/community/langchain_community/llms/databricks.py +++ b/libs/community/langchain_community/llms/databricks.py @@ -250,7 +250,6 @@ def _pickle_fn_to_hex_string(fn: Callable) -> str: class Databricks(LLM): - """Databricks serving endpoint or a cluster driver proxy app for LLM. It supports two endpoint types: @@ -374,6 +373,15 @@ class Databricks(LLM): If not provided, the task is automatically inferred from the endpoint. """ + allow_dangerous_deserialization: bool = False + """Whether to allow dangerous deserialization of the data which + involves loading data using pickle. + + If the data has been modified by a malicious actor, it can deliver a + malicious payload that results in execution of arbitrary code on the target + machine. + """ + _client: _DatabricksClientBase = PrivateAttr() class Config: @@ -435,6 +443,16 @@ class Databricks(LLM): return v def __init__(self, **data: Any): + if not data.get("allow_dangerous_deserialization"): + raise ValueError( + "This code relies on the pickle module. " + "You will need to set allow_dangerous_deserialization=True " + "if you want to opt-in to allow deserialization of data using pickle. " + "Data can be compromised by a malicious actor if " + "not handled properly to include " + "a malicious payload that when deserialized with " + "pickle can execute arbitrary code on your machine." 
+ ) if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]): data["transform_input_fn"] = _load_pickled_fn_from_hex_string( data["transform_input_fn"] diff --git a/libs/community/langchain_community/llms/self_hosted.py b/libs/community/langchain_community/llms/self_hosted.py index 043ffca136..e94f2d4054 100644 --- a/libs/community/langchain_community/llms/self_hosted.py +++ b/libs/community/langchain_community/llms/self_hosted.py @@ -137,6 +137,11 @@ class SelfHostedPipeline(LLM): model_reqs: List[str] = ["./", "torch"] """Requirements to install on hardware to inference the model.""" + allow_dangerous_deserialization: bool = False + """Allow deserialization using pickle which can be dangerous if + loading compromised data. + """ + class Config: """Configuration for this pydantic object.""" @@ -149,6 +154,16 @@ class SelfHostedPipeline(LLM): and run on the server, i.e. in a module and not a REPL or closure. Then, initialize the remote inference function. """ + if not kwargs.get("allow_dangerous_deserialization"): + raise ValueError( + "SelfHostedPipeline relies on the pickle module. " + "You will need to set allow_dangerous_deserialization=True " + "if you want to opt-in to allow deserialization of data using pickle. " + "Data can be compromised by a malicious actor if " + "not handled properly to include " + "a malicious payload that when deserialized with " + "pickle can execute arbitrary code. 
" + ) super().__init__(**kwargs) try: import runhouse as rh diff --git a/libs/community/langchain_community/vectorstores/annoy.py b/libs/community/langchain_community/vectorstores/annoy.py index b797fcf9bd..d35def219e 100644 --- a/libs/community/langchain_community/vectorstores/annoy.py +++ b/libs/community/langchain_community/vectorstores/annoy.py @@ -429,6 +429,8 @@ class Annoy(VectorStore): cls, folder_path: str, embeddings: Embeddings, + *, + allow_dangerous_deserialization: bool = False, ) -> Annoy: """Load Annoy index, docstore, and index_to_docstore_id to disk. @@ -436,7 +438,25 @@ class Annoy(VectorStore): folder_path: folder path to load index, docstore, and index_to_docstore_id from. embeddings: Embeddings to use when generating queries. + allow_dangerous_deserialization: whether to allow deserialization + of the data which involves loading a pickle file. + Pickle files can be modified by malicious actors to deliver a + malicious payload that results in execution of + arbitrary code on your machine. """ + if not allow_dangerous_deserialization: + raise ValueError( + "The de-serialization relies on loading a pickle file. " + "Pickle files can be modified to deliver a malicious payload that " + "results in execution of arbitrary code on your machine. " + "You will need to set `allow_dangerous_deserialization` to `True` to " + "enable deserialization. If you do this, make sure that you " + "trust the source of the data. For example, if you are loading a " + "file that you created, and know that no one else has modified the file, " + "then this is safe to do. Do not set this to `True` if you are loading " + "a file from an untrusted source (e.g., some random site on the " + "internet)." 
+ ) path = Path(folder_path) # load index separately since it is not picklable annoy = dependable_annoy_import() diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py index 4a7d65f912..32286bb178 100644 --- a/libs/community/langchain_community/vectorstores/faiss.py +++ b/libs/community/langchain_community/vectorstores/faiss.py @@ -1093,6 +1093,8 @@ class FAISS(VectorStore): folder_path: str, embeddings: Embeddings, index_name: str = "index", + *, + allow_dangerous_deserialization: bool = False, **kwargs: Any, ) -> FAISS: """Load FAISS index, docstore, and index_to_docstore_id from disk. @@ -1102,8 +1104,26 @@ class FAISS(VectorStore): and index_to_docstore_id from. embeddings: Embeddings to use when generating queries index_name: for saving with a specific index file name + allow_dangerous_deserialization: whether to allow deserialization + of the data which involves loading a pickle file. + Pickle files can be modified by malicious actors to deliver a + malicious payload that results in execution of + arbitrary code on your machine. asynchronous: whether to use async version or not """ + if not allow_dangerous_deserialization: + raise ValueError( + "The de-serialization relies on loading a pickle file. " + "Pickle files can be modified to deliver a malicious payload that " + "results in execution of arbitrary code on your machine. " + "You will need to set `allow_dangerous_deserialization` to `True` to " + "enable deserialization. If you do this, make sure that you " + "trust the source of the data. For example, if you are loading a " + "file that you created, and know that no one else has modified the file, " + "then this is safe to do. Do not set this to `True` if you are loading " + "a file from an untrusted source (e.g., some random site on the " + "internet)." 
+ ) path = Path(folder_path) # load index separately since it is not picklable faiss = dependable_faiss_import() diff --git a/libs/community/langchain_community/vectorstores/scann.py b/libs/community/langchain_community/vectorstores/scann.py index 11be67dcaa..67fc46096e 100644 --- a/libs/community/langchain_community/vectorstores/scann.py +++ b/libs/community/langchain_community/vectorstores/scann.py @@ -460,6 +460,8 @@ class ScaNN(VectorStore): folder_path: str, embedding: Embeddings, index_name: str = "index", + *, + allow_dangerous_deserialization: bool = False, **kwargs: Any, ) -> ScaNN: """Load ScaNN index, docstore, and index_to_docstore_id from disk. @@ -469,7 +471,25 @@ class ScaNN(VectorStore): and index_to_docstore_id from. embeddings: Embeddings to use when generating queries index_name: for saving with a specific index file name + allow_dangerous_deserialization: whether to allow deserialization + of the data which involves loading a pickle file. + Pickle files can be modified by malicious actors to deliver a + malicious payload that results in execution of + arbitrary code on your machine. """ + if not allow_dangerous_deserialization: + raise ValueError( + "The de-serialization relies on loading a pickle file. " + "Pickle files can be modified to deliver a malicious payload that " + "results in execution of arbitrary code on your machine. " + "You will need to set `allow_dangerous_deserialization` to `True` to " + "enable deserialization. If you do this, make sure that you " + "trust the source of the data. For example, if you are loading a " + "file that you created, and know that no one else has modified the file, " + "then this is safe to do. Do not set this to `True` if you are loading " + "a file from an untrusted source (e.g., some random site on the " + "internet)." 
+ ) path = Path(folder_path) scann_path = path / "{index_name}.scann".format(index_name=index_name) scann_path.mkdir(exist_ok=True, parents=True) diff --git a/libs/community/langchain_community/vectorstores/tiledb.py b/libs/community/langchain_community/vectorstores/tiledb.py index 48d79c5a5e..3b4a5cb5b3 100644 --- a/libs/community/langchain_community/vectorstores/tiledb.py +++ b/libs/community/langchain_community/vectorstores/tiledb.py @@ -87,9 +87,28 @@ class TileDB(VectorStore): docs_array_uri: str = "", config: Optional[Mapping[str, Any]] = None, timestamp: Any = None, + allow_dangerous_deserialization: bool = False, **kwargs: Any, ): - """Initialize with necessary components.""" + """Initialize with necessary components. + + Args: + allow_dangerous_deserialization: whether to allow deserialization + of the data which involves loading data using pickle. + Data can be modified by malicious actors to deliver a + malicious payload that results in execution of + arbitrary code on your machine. + """ + if not allow_dangerous_deserialization: + raise ValueError( + "TileDB relies on pickle for serialization and deserialization. " + "This can be dangerous if the data is intercepted and/or modified " + "by malicious actors prior to being de-serialized. " + "If you are sure that the data is safe from modification, you can " + "set allow_dangerous_deserialization=True to proceed. " + "Loading of compromised data using pickle can result in execution of " + "arbitrary code on your machine." 
+ ) self.embedding = embedding self.embedding_function = embedding.embed_query self.index_uri = index_uri diff --git a/libs/community/tests/integration_tests/vectorstores/test_annoy.py b/libs/community/tests/integration_tests/vectorstores/test_annoy.py index 600f716667..1950dcd277 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_annoy.py +++ b/libs/community/tests/integration_tests/vectorstores/test_annoy.py @@ -116,7 +116,9 @@ def test_annoy_local_save_load() -> None: temp_dir = tempfile.TemporaryDirectory() docsearch.save_local(temp_dir.name) - loaded_docsearch = Annoy.load_local(temp_dir.name, FakeEmbeddings()) + loaded_docsearch = Annoy.load_local( + temp_dir.name, FakeEmbeddings(), allow_dangerous_deserialization=True + ) assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__ diff --git a/libs/community/tests/integration_tests/vectorstores/test_scann.py b/libs/community/tests/integration_tests/vectorstores/test_scann.py index f945b003a6..b060df5adf 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_scann.py +++ b/libs/community/tests/integration_tests/vectorstores/test_scann.py @@ -252,7 +252,9 @@ def test_scann_local_save_load() -> None: temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: docsearch.save_local(temp_folder) - new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings()) + new_docsearch = ScaNN.load_local( + temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True + ) assert new_docsearch.index is not None diff --git a/libs/community/tests/unit_tests/llms/test_databricks.py b/libs/community/tests/unit_tests/llms/test_databricks.py index 6b1a3d9881..fe93101d68 100644 --- a/libs/community/tests/unit_tests/llms/test_databricks.py +++ b/libs/community/tests/unit_tests/llms/test_databricks.py @@ -44,8 
+44,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None: monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token") llm = Databricks( - endpoint_name="databricks-mixtral-8x7b-instruct", + endpoint_name="some_end_point_name", # Value should not matter for this test transform_input_fn=transform_input, + allow_dangerous_deserialization=True, ) params = llm._default_params pickled_string = cloudpickle.dumps(transform_input).hex() diff --git a/libs/community/tests/unit_tests/vectorstores/test_faiss.py b/libs/community/tests/unit_tests/vectorstores/test_faiss.py index cedecc8ada..e6a32176c1 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_faiss.py +++ b/libs/community/tests/unit_tests/vectorstores/test_faiss.py @@ -608,7 +608,9 @@ def test_faiss_local_save_load() -> None: temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: docsearch.save_local(temp_folder) - new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings()) + new_docsearch = FAISS.load_local( + temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True + ) assert new_docsearch.index is not None @@ -620,7 +622,9 @@ async def test_faiss_async_local_save_load() -> None: temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: docsearch.save_local(temp_folder) - new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings()) + new_docsearch = FAISS.load_local( + temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True + ) assert new_docsearch.index is not None