community[patch]: Support SerDe transform functions in Databricks LLM (#16752)

**Description:** The Databricks LLM does not support SerDe for `transform_input_fn` and `transform_output_fn`, so after saving and loading, the LLM is broken. This PR serializes these functions into hex strings using pickle and saves the hex strings in the YAML file. Using pickle to serialize a function can be flaky, but this is a simple workaround that unblocks many use cases. If more sophisticated SerDe is needed, we can improve it later.
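For reference, the round-trip is plain pickle plus hex encoding. A minimal sketch (with a stand-in function `add_one`):

```python
import pickle

def add_one(x: int) -> int:
    return x + 1

# Serialize: pickle the function and render the bytes as a hex string
# that can live in a YAML file.
hex_string = pickle.dumps(add_one).hex()

# Deserialize: parse the hex back into bytes and unpickle.
restored = pickle.loads(bytes.fromhex(hex_string))
assert restored(1) == 2
```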

Tests:
Added a simple unit test.
Also tested manually on Databricks, and it works well.
The saved YAML looks like this:
```
llm:
      _type: databricks
      cluster_driver_port: null
      cluster_id: null
      databricks_uri: databricks
      endpoint_name: databricks-mixtral-8x7b-instruct
      extra_params: {}
      host: e2-dogfood.staging.cloud.databricks.com
      max_tokens: null
      model_kwargs: null
      n: 1
      stop: null
      task: null
      temperature: 0.0
      transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e
      transform_output_fn: null
```
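The `transform_input_fn` value above is a by-reference pickle: disassembling it shows only the qualified name `__main__.transform_input`, not the function body. This is the main reason pickle-based SerDe can be flaky: the same symbol must be importable again when the YAML is loaded. A quick way to inspect the payload:

```python
import pickletools

payload = bytes.fromhex(
    "80049520000000000000008c085f5f6d61696e5f5f94"
    "8c0f7472616e73666f726d5f696e7075749493942e"
)
# Prints the opcode stream: SHORT_BINUNICODE '__main__',
# SHORT_BINUNICODE 'transform_input', then STACK_GLOBAL --
# a reference to the function, not its code.
pickletools.dis(payload)
```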

@baskaryan
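The script used for the manual test: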

```python
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.llms import Databricks
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import mlflow

embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

def transform_input(**request):
    request["messages"] = [
        {
            "role": "user",
            "content": request["prompt"],
        }
    ]
    del request["prompt"]
    return request

llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input)

persist_dir = "faiss_databricks_embedding"

# Create the vector db, persist the db to a local fs folder
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
db = FAISS.from_documents(docs, embeddings)
db.save_local(persist_dir)

def load_retriever(persist_directory):
    embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
    vectorstore = FAISS.load_local(persist_directory, embeddings)
    return vectorstore.as_retriever()

retriever = load_retriever(persist_dir)
retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever)
with mlflow.start_run() as run:
    logged_model = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
    )

# Load the retrievalQA chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))

```

In `langchain_community/llms/databricks.py`:

```diff
@@ -1,4 +1,6 @@
 import os
+import pickle
+import re
 import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Mapping, Optional
```
```diff
@@ -212,6 +214,32 @@ def get_default_api_token() -> str:
     return api_token


+def _is_hex_string(data: str) -> bool:
+    """Checks if data is a valid hexadecimal string using a regular expression."""
+    if not isinstance(data, str):
+        return False
+    pattern = r"^[0-9a-fA-F]+$"
+    return bool(re.match(pattern, data))
+
+
+def _load_pickled_fn_from_hex_string(data: str) -> Callable:
+    """Loads a pickled function from a hexadecimal string."""
+    try:
+        return pickle.loads(bytes.fromhex(data))
+    except Exception as e:
+        raise ValueError(
+            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
+        )
+
+
+def _pickle_fn_to_hex_string(fn: Callable) -> str:
+    """Pickles a function and returns the hexadecimal string."""
+    try:
+        return pickle.dumps(fn).hex()
+    except Exception as e:
+        raise ValueError(f"Failed to pickle the function: {e}")
+
+
 class Databricks(LLM):
     """Databricks serving endpoint or a cluster driver proxy app for LLM.
```
```diff
@@ -398,6 +426,17 @@ class Databricks(LLM):
         return v

     def __init__(self, **data: Any):
+        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
+            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
+                data["transform_input_fn"]
+            )
+        if "transform_output_fn" in data and _is_hex_string(
+            data["transform_output_fn"]
+        ):
+            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
+                data["transform_output_fn"]
+            )
         super().__init__(**data)
         if self.model_kwargs is not None and self.extra_params is not None:
             raise ValueError("Cannot set both model_kwargs and extra_params.")
```
```diff
@@ -443,9 +482,12 @@ class Databricks(LLM):
             "max_tokens": self.max_tokens,
             "extra_params": self.extra_params,
             "task": self.task,
-            # TODO: Support saving transform_input_fn and transform_output_fn
-            # "transform_input_fn": self.transform_input_fn,
-            # "transform_output_fn": self.transform_output_fn,
+            "transform_input_fn": None
+            if self.transform_input_fn is None
+            else _pickle_fn_to_hex_string(self.transform_input_fn),
+            "transform_output_fn": None
+            if self.transform_output_fn is None
+            else _pickle_fn_to_hex_string(self.transform_output_fn),
         }

     @property
```
New unit test file:

```diff
@@ -0,0 +1,46 @@
+"""test Databricks LLM"""
+import pickle
+from typing import Any, Dict
+
+from pytest import MonkeyPatch
+
+from langchain_community.llms.databricks import Databricks
+
+
+class MockDatabricksServingEndpointClient:
+    def __init__(
+        self,
+        host: str,
+        api_token: str,
+        endpoint_name: str,
+        databricks_uri: str,
+        task: str,
+    ):
+        self.host = host
+        self.api_token = api_token
+        self.endpoint_name = endpoint_name
+        self.databricks_uri = databricks_uri
+        self.task = task
+
+
+def transform_input(**request: Any) -> Dict[str, Any]:
+    request["messages"] = [{"role": "user", "content": request["prompt"]}]
+    del request["prompt"]
+    return request
+
+
+def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setattr(
+        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
+        MockDatabricksServingEndpointClient,
+    )
+    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
+
+    llm = Databricks(
+        endpoint_name="databricks-mixtral-8x7b-instruct",
+        transform_input_fn=transform_input,
+    )
+    params = llm._default_params
+    pickled_string = pickle.dumps(transform_input).hex()
+    assert params["transform_input_fn"] == pickled_string
```
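A companion test for the deserialization path (not part of this PR, just a sketch reusing the same mocks) could feed the hex string back through the constructor:

```python
def test_transform_input_fn_roundtrip(monkeypatch: MonkeyPatch) -> None:
    monkeypatch.setattr(
        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
        MockDatabricksServingEndpointClient,
    )
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    # Construct from the hex string, as happens when a saved YAML is loaded.
    llm = Databricks(
        endpoint_name="databricks-mixtral-8x7b-instruct",
        transform_input_fn=pickle.dumps(transform_input).hex(),
    )
    # __init__ should have unpickled the hex back into the original function.
    assert llm.transform_input_fn is transform_input
```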