import os
import pickle
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional


def _is_hex_string(data: str) -> bool:
    """Return whether ``data`` is a non-empty string made only of hex digits.

    Args:
        data: The value to inspect. Non-string values are accepted and simply
            yield ``False``.

    Returns:
        ``True`` if ``data`` is a ``str`` consisting solely of one or more
        characters from ``0-9``, ``a-f`` or ``A-F``; ``False`` otherwise.
    """
    if not isinstance(data, str):
        return False
    # ``re.fullmatch`` rather than ``re.match`` with ``$``: ``$`` also matches
    # just before a trailing newline, which would let "ab\n" through.
    return re.fullmatch(r"[0-9a-fA-F]+", data) is not None


def _load_pickled_fn_from_hex_string(data: str) -> Callable:
    """Rebuild a callable from the hex encoding of its pickle.

    Args:
        data: Hexadecimal text as produced by ``_pickle_fn_to_hex_string``.

    Returns:
        The unpickled callable.

    Raises:
        ValueError: If the text cannot be hex-decoded or unpickled, or if the
            unpickled object is not callable.
    """
    # SECURITY: ``pickle.loads`` can execute arbitrary code while loading.
    # Only pass strings that come from a trusted source; never feed this
    # function data received from an untrusted party.
    try:
        fn = pickle.loads(bytes.fromhex(data))
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
        ) from e
    # Checked outside the ``try`` so this error is not re-wrapped by the
    # handler above.
    if not callable(fn):
        raise ValueError(
            "Failed to load the pickled function from a hexadecimal string. "
            f"Error: expected a callable, got {type(fn).__name__}"
        )
    return fn


def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickle ``fn`` and return the pickle bytes as a hexadecimal string.

    Args:
        fn: The callable to serialize.

    Returns:
        The hex encoding of ``pickle.dumps(fn)``.

    Raises:
        ValueError: If ``fn`` cannot be pickled.
    """
    try:
        return pickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}") from e
"""Unit tests for the Databricks LLM wrapper."""
import pickle
from typing import Any, Dict

from pytest import MonkeyPatch

from langchain_community.llms.databricks import Databricks


class MockDatabricksServingEndpointClient:
    """Test double that only stores the arguments it is constructed with."""

    def __init__(
        self,
        host: str,
        api_token: str,
        endpoint_name: str,
        databricks_uri: str,
        task: str,
    ):
        # No network access: just remember what the caller passed in.
        self.host, self.api_token = host, api_token
        self.endpoint_name, self.databricks_uri = endpoint_name, databricks_uri
        self.task = task


def transform_input(**request: Any) -> Dict[str, Any]:
    """Replace the ``prompt`` entry with a single-user-message ``messages`` entry.

    Raises ``KeyError`` when no ``prompt`` keyword is supplied.
    """
    prompt = request.pop("prompt")
    request["messages"] = [{"role": "user", "content": prompt}]
    return request


def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
    """``_default_params`` reports ``transform_input_fn`` as a hex-encoded pickle."""
    # Swap in the mock client and provide the default credentials via env vars.
    monkeypatch.setattr(
        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
        MockDatabricksServingEndpointClient,
    )
    for env_name, env_value in (
        ("DATABRICKS_HOST", "my-default-host"),
        ("DATABRICKS_TOKEN", "my-default-token"),
    ):
        monkeypatch.setenv(env_name, env_value)

    expected_hex = pickle.dumps(transform_input).hex()
    model = Databricks(
        endpoint_name="databricks-mixtral-8x7b-instruct",
        transform_input_fn=transform_input,
    )

    assert model._default_params["transform_input_fn"] == expected_hex