|
|
|
@ -56,7 +56,10 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
|
|
|
|
|
assert params["transform_input_fn"] == pickled_string
|
|
|
|
|
|
|
|
|
|
request = {"prompt": "What is the meaning of life?"}
|
|
|
|
|
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
|
|
|
|
|
fn = _load_pickled_fn_from_hex_string(
|
|
|
|
|
data=params["transform_input_fn"],
|
|
|
|
|
allow_dangerous_deserialization=True,
|
|
|
|
|
)
|
|
|
|
|
assert fn(**request) == transform_input(**request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -69,15 +72,44 @@ def test_saving_loading_llm(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
|
|
|
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
|
|
|
|
|
|
|
|
|
llm = Databricks(
|
|
|
|
|
endpoint_name="chat", temperature=0.1, allow_dangerous_deserialization=True
|
|
|
|
|
endpoint_name="chat",
|
|
|
|
|
temperature=0.1,
|
|
|
|
|
)
|
|
|
|
|
llm.save(file_path=tmp_path / "databricks.yaml")
|
|
|
|
|
|
|
|
|
|
# Loading without allowing_dangerous_deserialization=True should raise an error.
|
|
|
|
|
loaded_llm = load_llm(tmp_path / "databricks.yaml")
|
|
|
|
|
assert_llm_equality(llm, loaded_llm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("cloudpickle")
|
|
|
|
|
def test_saving_loading_llm_dangerous_serde_check(
|
|
|
|
|
monkeypatch: MonkeyPatch, tmp_path: Path
|
|
|
|
|
) -> None:
|
|
|
|
|
monkeypatch.setattr(
|
|
|
|
|
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
|
|
|
|
|
MockDatabricksServingEndpointClient,
|
|
|
|
|
)
|
|
|
|
|
monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
|
|
|
|
|
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
|
|
|
|
|
|
|
|
|
|
llm1 = Databricks(
|
|
|
|
|
endpoint_name="chat",
|
|
|
|
|
temperature=0.1,
|
|
|
|
|
transform_input_fn=lambda x, y, **kwargs: {},
|
|
|
|
|
)
|
|
|
|
|
llm1.save(file_path=tmp_path / "databricks1.yaml")
|
|
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match="This code relies on the pickle module."):
|
|
|
|
|
load_llm(tmp_path / "databricks.yaml")
|
|
|
|
|
load_llm(tmp_path / "databricks1.yaml")
|
|
|
|
|
|
|
|
|
|
loaded_llm = load_llm(
|
|
|
|
|
tmp_path / "databricks.yaml", allow_dangerous_deserialization=True
|
|
|
|
|
load_llm(tmp_path / "databricks1.yaml", allow_dangerous_deserialization=True)
|
|
|
|
|
|
|
|
|
|
llm2 = Databricks(
|
|
|
|
|
endpoint_name="chat", temperature=0.1, transform_output_fn=lambda x: "test"
|
|
|
|
|
)
|
|
|
|
|
assert_llm_equality(llm, loaded_llm)
|
|
|
|
|
llm2.save(file_path=tmp_path / "databricks2.yaml")
|
|
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match="This code relies on the pickle module."):
|
|
|
|
|
load_llm(tmp_path / "databricks2.yaml")
|
|
|
|
|
|
|
|
|
|
load_llm(tmp_path / "databricks2.yaml", allow_dangerous_deserialization=True)
|
|
|
|
|