- **Description:** Databricks SerDe uses cloudpickle instead of pickle
when serializing a user-defined function transform_input_fn since pickle
does not support functions defined in `__main__`, and cloudpickle
supports this.
- **Dependencies:** cloudpickle>=2.0.0
Added a unit test.
**Description:** Databricks LLM does not support SerDe the
transform_input_fn and transform_output_fn. After saving and loading,
the LLM will be broken. This PR serialize these functions into a hex
string using pickle, and saving the hex string in the yaml file. Using
pickle to serialize a function can be flaky, but this is a simple
workaround that unblocks many use cases. If more sophisticated SerDe is
needed, we can improve it later.
Test:
Added a simple unit test.
I did manual test on Databricks and it works well.
The saved yaml looks like:
```
llm:
_type: databricks
cluster_driver_port: null
cluster_id: null
databricks_uri: databricks
endpoint_name: databricks-mixtral-8x7b-instruct
extra_params: {}
host: e2-dogfood.staging.cloud.databricks.com
max_tokens: null
model_kwargs: null
n: 1
stop: null
task: null
temperature: 0.0
transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e
transform_output_fn: null
```
@baskaryan
```python
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.llms import Databricks
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import mlflow
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
def transform_input(**request):
request["messages"] = [
{
"role": "user",
"content": request["prompt"]
}
]
del request["prompt"]
return request
llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input)
persist_dir = "faiss_databricks_embedding"
# Create the vector db, persist the db to a local fs folder
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
db = FAISS.from_documents(docs, embeddings)
db.save_local(persist_dir)
def load_retriever(persist_directory):
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
vectorstore = FAISS.load_local(persist_directory, embeddings)
return vectorstore.as_retriever()
retriever = load_retriever(persist_dir)
retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever)
with mlflow.start_run() as run:
logged_model = mlflow.langchain.log_model(
retrievalQA,
artifact_path="retrieval_qa",
loader_fn=load_retriever,
persist_dir=persist_dir,
)
# Load the retrievalQA chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))
```