community[patch]: Support SerDe transform functions in Databricks LLM (#16752)

**Description:** Databricks LLM does not support SerDe the
transform_input_fn and transform_output_fn. After saving and loading,
the LLM will be broken. This PR serialize these functions into a hex
string using pickle, and saving the hex string in the yaml file. Using
pickle to serialize a function can be flaky, but this is a simple
workaround that unblocks many use cases. If more sophisticated SerDe is
needed, we can improve it later.

Test:
Added a simple unit test.
I did manual test on Databricks and it works well.
The saved yaml looks like:
```
llm:
      _type: databricks
      cluster_driver_port: null
      cluster_id: null
      databricks_uri: databricks
      endpoint_name: databricks-mixtral-8x7b-instruct
      extra_params: {}
      host: e2-dogfood.staging.cloud.databricks.com
      max_tokens: null
      model_kwargs: null
      n: 1
      stop: null
      task: null
      temperature: 0.0
      transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e
      transform_output_fn: null
```

@baskaryan

```python
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.llms import Databricks
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import mlflow

embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

def transform_input(**request):
  request["messages"] = [
    {
      "role": "user",
      "content": request["prompt"]
    }
  ]
  del request["prompt"]
  return request

llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input)

persist_dir = "faiss_databricks_embedding"

# Create the vector db, persist the db to a local fs folder
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
db = FAISS.from_documents(docs, embeddings)
db.save_local(persist_dir)

def load_retriever(persist_directory):
    embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
    vectorstore = FAISS.load_local(persist_directory, embeddings)
    return vectorstore.as_retriever()

retriever = load_retriever(persist_dir)
retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever)
with mlflow.start_run() as run:
    logged_model = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
    )

# Load the retrievalQA chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))

```
pull/17263/head
Liang Zhang 5 months ago committed by GitHub
parent ce22e10c4b
commit 7306600e2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,6 @@
import os import os
import pickle
import re
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional from typing import Any, Callable, Dict, List, Mapping, Optional
@ -212,6 +214,32 @@ def get_default_api_token() -> str:
return api_token return api_token
def _is_hex_string(data: str) -> bool:
"""Checks if a data is a valid hexadecimal string using a regular expression."""
if not isinstance(data, str):
return False
pattern = r"^[0-9a-fA-F]+$"
return bool(re.match(pattern, data))
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
"""Loads a pickled function from a hexadecimal string."""
try:
return pickle.loads(bytes.fromhex(data))
except Exception as e:
raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
)
def _pickle_fn_to_hex_string(fn: Callable) -> str:
"""Pickles a function and returns the hexadecimal string."""
try:
return pickle.dumps(fn).hex()
except Exception as e:
raise ValueError(f"Failed to pickle the function: {e}")
class Databricks(LLM): class Databricks(LLM):
"""Databricks serving endpoint or a cluster driver proxy app for LLM. """Databricks serving endpoint or a cluster driver proxy app for LLM.
@ -398,6 +426,17 @@ class Databricks(LLM):
return v return v
def __init__(self, **data: Any): def __init__(self, **data: Any):
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
data["transform_input_fn"]
)
if "transform_output_fn" in data and _is_hex_string(
data["transform_output_fn"]
):
data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
data["transform_output_fn"]
)
super().__init__(**data) super().__init__(**data)
if self.model_kwargs is not None and self.extra_params is not None: if self.model_kwargs is not None and self.extra_params is not None:
raise ValueError("Cannot set both extra_params and extra_params.") raise ValueError("Cannot set both extra_params and extra_params.")
@ -443,9 +482,12 @@ class Databricks(LLM):
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
"extra_params": self.extra_params, "extra_params": self.extra_params,
"task": self.task, "task": self.task,
# TODO: Support saving transform_input_fn and transform_output_fn "transform_input_fn": None
# "transform_input_fn": self.transform_input_fn, if self.transform_input_fn is None
# "transform_output_fn": self.transform_output_fn, else _pickle_fn_to_hex_string(self.transform_input_fn),
"transform_output_fn": None
if self.transform_output_fn is None
else _pickle_fn_to_hex_string(self.transform_output_fn),
} }
@property @property

@ -0,0 +1,46 @@
"""test Databricks LLM"""
import pickle
from typing import Any, Dict
from pytest import MonkeyPatch
from langchain_community.llms.databricks import Databricks
class MockDatabricksServingEndpointClient:
def __init__(
self,
host: str,
api_token: str,
endpoint_name: str,
databricks_uri: str,
task: str,
):
self.host = host
self.api_token = api_token
self.endpoint_name = endpoint_name
self.databricks_uri = databricks_uri
self.task = task
def transform_input(**request: Any) -> Dict[str, Any]:
request["messages"] = [{"role": "user", "content": request["prompt"]}]
del request["prompt"]
return request
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
)
monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
llm = Databricks(
endpoint_name="databricks-mixtral-8x7b-instruct",
transform_input_fn=transform_input,
)
params = llm._default_params
pickled_string = pickle.dumps(transform_input).hex()
assert params["transform_input_fn"] == pickled_string
Loading…
Cancel
Save