From 7306600e2f1b8a3cb31b37693f74ebc0d03efd8a Mon Sep 17 00:00:00 2001 From: Liang Zhang Date: Thu, 8 Feb 2024 13:09:50 -0800 Subject: [PATCH] community[patch]: Support SerDe transform functions in Databricks LLM (#16752) **Description:** Databricks LLM does not support SerDe the transform_input_fn and transform_output_fn. After saving and loading, the LLM will be broken. This PR serialize these functions into a hex string using pickle, and saving the hex string in the yaml file. Using pickle to serialize a function can be flaky, but this is a simple workaround that unblocks many use cases. If more sophisticated SerDe is needed, we can improve it later. Test: Added a simple unit test. I did manual test on Databricks and it works well. The saved yaml looks like: ``` llm: _type: databricks cluster_driver_port: null cluster_id: null databricks_uri: databricks endpoint_name: databricks-mixtral-8x7b-instruct extra_params: {} host: e2-dogfood.staging.cloud.databricks.com max_tokens: null model_kwargs: null n: 1 stop: null task: null temperature: 0.0 transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e transform_output_fn: null ``` @baskaryan ```python from langchain_community.embeddings import DatabricksEmbeddings from langchain_community.llms import Databricks from langchain.chains import RetrievalQA from langchain.document_loaders import TextLoader from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import FAISS import mlflow embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en") def transform_input(**request): request["messages"] = [ { "role": "user", "content": request["prompt"] } ] del request["prompt"] return request llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input) persist_dir = "faiss_databricks_embedding" # Create the vector db, persist the db to a local fs folder loader = TextLoader("state_of_the_union.txt") documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) docs = text_splitter.split_documents(documents) db = FAISS.from_documents(docs, embeddings) db.save_local(persist_dir) def load_retriever(persist_directory): embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en") vectorstore = FAISS.load_local(persist_directory, embeddings) return vectorstore.as_retriever() retriever = load_retriever(persist_dir) retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever) with mlflow.start_run() as run: logged_model = mlflow.langchain.log_model( retrievalQA, artifact_path="retrieval_qa", loader_fn=load_retriever, persist_dir=persist_dir, ) # Load the retrievalQA chain loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri) print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}])) ``` --- .../langchain_community/llms/databricks.py | 48 +++++++++++++++++-- .../tests/unit_tests/llms/test_databricks.py | 46 ++++++++++++++++++ 2 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 libs/community/tests/unit_tests/llms/test_databricks.py diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py index 9333629f95..c92f8ba221 100644 --- a/libs/community/langchain_community/llms/databricks.py +++ b/libs/community/langchain_community/llms/databricks.py @@ -1,4 +1,6 @@ import os +import pickle +import re import warnings from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Mapping, Optional @@ -212,6 +214,32 @@ def get_default_api_token() -> str: return api_token +def _is_hex_string(data: str) -> bool: + """Checks if a data is a valid hexadecimal string using a regular expression.""" + if not isinstance(data, str): + return False + pattern = r"^[0-9a-fA-F]+$" + return bool(re.match(pattern, data)) + + +def _load_pickled_fn_from_hex_string(data: str) -> Callable: + """Loads a pickled function from a hexadecimal string.""" + try: + return pickle.loads(bytes.fromhex(data)) + except Exception as e: + raise ValueError( + f"Failed to load the pickled function from a hexadecimal string. Error: {e}" + ) + + +def _pickle_fn_to_hex_string(fn: Callable) -> str: + """Pickles a function and returns the hexadecimal string.""" + try: + return pickle.dumps(fn).hex() + except Exception as e: + raise ValueError(f"Failed to pickle the function: {e}") + + class Databricks(LLM): """Databricks serving endpoint or a cluster driver proxy app for LLM. @@ -398,6 +426,17 @@ class Databricks(LLM): return v def __init__(self, **data: Any): + if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]): + data["transform_input_fn"] = _load_pickled_fn_from_hex_string( + data["transform_input_fn"] + ) + if "transform_output_fn" in data and _is_hex_string( + data["transform_output_fn"] + ): + data["transform_output_fn"] = _load_pickled_fn_from_hex_string( + data["transform_output_fn"] + ) + super().__init__(**data) if self.model_kwargs is not None and self.extra_params is not None: raise ValueError("Cannot set both extra_params and extra_params.") @@ -443,9 +482,12 @@ class Databricks(LLM): "max_tokens": self.max_tokens, "extra_params": self.extra_params, "task": self.task, - # TODO: Support saving transform_input_fn and transform_output_fn - # "transform_input_fn": self.transform_input_fn, - # "transform_output_fn": self.transform_output_fn, + "transform_input_fn": None + if self.transform_input_fn is None + else _pickle_fn_to_hex_string(self.transform_input_fn), + "transform_output_fn": None + if self.transform_output_fn is None + else _pickle_fn_to_hex_string(self.transform_output_fn), } @property diff --git a/libs/community/tests/unit_tests/llms/test_databricks.py b/libs/community/tests/unit_tests/llms/test_databricks.py new file mode 100644 index 0000000000..7d3809e270 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_databricks.py @@ -0,0 +1,46 @@ +"""test Databricks LLM""" +import pickle +from typing import Any, Dict + +from pytest import MonkeyPatch + +from langchain_community.llms.databricks import Databricks + + +class MockDatabricksServingEndpointClient: + def __init__( + self, + host: str, + api_token: str, + endpoint_name: str, + databricks_uri: str, + task: str, + ): + self.host = host + self.api_token = api_token + self.endpoint_name = endpoint_name + self.databricks_uri = databricks_uri + self.task = task + + +def transform_input(**request: Any) -> Dict[str, Any]: + request["messages"] = [{"role": "user", "content": request["prompt"]}] + del request["prompt"] + return request + + +def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None: + monkeypatch.setattr( + "langchain_community.llms.databricks._DatabricksServingEndpointClient", + MockDatabricksServingEndpointClient, + ) + monkeypatch.setenv("DATABRICKS_HOST", "my-default-host") + monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token") + + llm = Databricks( + endpoint_name="databricks-mixtral-8x7b-instruct", + transform_input_fn=transform_input, + ) + params = llm._default_params + pickled_string = pickle.dumps(transform_input).hex() + assert params["transform_input_fn"] == pickled_string