community[patch]: Support SerDe transform functions in Databricks LLM (#16752)
**Description:** The Databricks LLM does not support SerDe of
transform_input_fn and transform_output_fn, so after saving and loading,
the LLM is broken. This PR serializes these functions into a hex
string using pickle and saves the hex string in the YAML file. Using
pickle to serialize a function can be flaky, but this is a simple
workaround that unblocks many use cases. If more sophisticated SerDe is
needed, we can improve it later.

Test:
Added a simple unit test.
I also tested manually on Databricks and it works well.
The saved YAML looks like:
```
llm:
  _type: databricks
  cluster_driver_port: null
  cluster_id: null
  databricks_uri: databricks
  endpoint_name: databricks-mixtral-8x7b-instruct
  extra_params: {}
  host: e2-dogfood.staging.cloud.databricks.com
  max_tokens: null
  model_kwargs: null
  n: 1
  stop: null
  task: null
  temperature: 0.0
  transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e
  transform_output_fn: null
```
@baskaryan
```python
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.llms import Databricks
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import mlflow

embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

def transform_input(**request):
    request["messages"] = [
        {
            "role": "user",
            "content": request["prompt"]
        }
    ]
    del request["prompt"]
    return request

llm = Databricks(endpoint_name="databricks-mixtral-8x7b-instruct", transform_input_fn=transform_input)

persist_dir = "faiss_databricks_embedding"

# Create the vector db and persist it to a local filesystem folder.
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
db = FAISS.from_documents(docs, embeddings)
db.save_local(persist_dir)

def load_retriever(persist_directory):
    embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
    vectorstore = FAISS.load_local(persist_directory, embeddings)
    return vectorstore.as_retriever()

retriever = load_retriever(persist_dir)
retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever)

with mlflow.start_run() as run:
    logged_model = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
    )

# Load the retrievalQA chain.
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))
```
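For reference, a minimal sketch of the hex round-trip described above, using plain `pickle` the same way the helper functions below do (the function name here is illustrative):

```python
import pickle

def transform_input(**request):
    return request

# Serialize: pickle the function object and render the bytes as hex;
# this hex string is what ends up in the saved YAML file.
hex_str = pickle.dumps(transform_input).hex()

# Deserialize: parse the hex back into bytes and unpickle.
restored_fn = pickle.loads(bytes.fromhex(hex_str))
assert restored_fn(prompt="hi") == {"prompt": "hi"}
```

Note that pickling a function by reference requires the same module to be importable at load time, which is one reason this approach can be flaky.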
import os
import pickle
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

__all__ = ["Databricks"]

class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        return False


def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]

community[patch]: Add param "task" to Databricks LLM to work around serialization of transform_output_fn (#14933)
**What is the code to reproduce it?**
```python
from langchain.chains import LLMChain, load_chain
from langchain.llms import Databricks
from langchain.prompts import PromptTemplate

def transform_output(response):
    # Extract the answer from the response.
    return str(response["candidates"][0]["text"])

def transform_input(**request):
    full_prompt = f"""{request["prompt"]}
Be Concise.
"""
    request["prompt"] = full_prompt
    return request

chat_model = Databricks(
    endpoint_name="llama2-13B-chat-Brambles",
    transform_input_fn=transform_input,
    transform_output_fn=transform_output,
    verbose=True,
)
print(f"Test chat model: {chat_model('What is Apache Spark')}")  # This works.
llm_chain = LLMChain(llm=chat_model, prompt=PromptTemplate.from_template("{chat_input}"))
llm_chain("colorful socks")  # This works.
llm_chain.save("databricks_llm_chain.yaml")  # transform_input_fn and transform_output_fn are not serialized into the model YAML file.
loaded_chain = load_chain("databricks_llm_chain.yaml")  # The Databricks LLM is recreated with transform_input_fn=None, transform_output_fn=None.
loaded_chain("colorful socks")  # This errors: transform_output_fn is needed to produce the correct output.
```
Error:
```
File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-6c34afab-3473-421d-877f-1ef18930ef4d/lib/python3.10/site-packages/pydantic/v1/main.py", line 341, in __init__
    raise validation_error
pydantic.v1.error_wrappers.ValidationError: 1 validation error for Generation
text
  str type expected (type=type_error.str)
request payload: {'query': 'What is a databricks notebook?'}
```
**What does the error mean?**
When the LLM generates an answer, it is represented by a `Generation` data
object, which takes a str field called `text`, e.g. `Generation(text="blah")`.
However, the Databricks LLM tried to put a non-str value into `text`, e.g.
`Generation(text={"candidates": [{"text": "blah"}]})`, so pydantic raises a
validation error.
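A minimal sketch of the failing validation (assuming the `Generation` class from `langchain_core.outputs`; the exact import path depends on the installed version):

```python
from langchain_core.outputs import Generation

Generation(text="blah")  # OK: text is a str
Generation(text={"candidates": [{"text": "blah"}]})  # raises a pydantic ValidationError
```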
**Why does the output format become incorrect after saving and loading the
Databricks LLM?**
The Databricks LLM does not support serializing transform_input_fn and
transform_output_fn, so they are not serialized into the model YAML file. When
the Databricks LLM is loaded, it is recreated with transform_input_fn=None and
transform_output_fn=None. Without transform_output_fn, the output text is not
unwrapped, which produces the error above. The missing transform_input_fn also
causes the additional prompt "Be Concise." to be lost after saving and loading.
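As a rough sketch of the workaround this param enables (endpoint name reused from the reproduce code above): because `task` is a plain string, it survives `save`/`load_chain`, and the loaded LLM can unwrap the response with the built-in `llama2/chat` transform instead of a custom `transform_output_fn`:

```python
from langchain.chains import LLMChain, load_chain
from langchain.llms import Databricks
from langchain.prompts import PromptTemplate

chat_model = Databricks(
    endpoint_name="llama2-13B-chat-Brambles",
    task="llama2/chat",  # a plain string, so it serializes into the YAML
)
llm_chain = LLMChain(llm=chat_model, prompt=PromptTemplate.from_template("{chat_input}"))
llm_chain.save("databricks_llm_chain.yaml")
loaded_chain = load_chain("databricks_llm_chain.yaml")  # task is restored from the YAML
loaded_chain("colorful socks")  # output unwrapped by the built-in llama2/chat transform
```

Note this only addresses the output side; a custom `transform_input_fn` (the "Be Concise." prompt addition) still does not survive serialization.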
---------
Co-authored-by: Bagatur <baskaryan@gmail.com>

def _transform_llama2_chat(response: Dict[str, Any]) -> str:
    return response["candidates"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]

class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
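            # Illustrative shapes (derived from the code above): a request like
            # {"prompt": "hi"} is wrapped as {"dataframe_records": [{"prompt": "hi"}]},
            # and the endpoint replies with a {"predictions": ...} payload.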
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        resp = self._post(self.api_url, request)
        return transform_output_fn(resp) if transform_output_fn else resp

def get_repl_context() -> Any:
    """Gets the notebook REPL context if running inside a Databricks notebook.

    Returns None otherwise.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )


def get_default_host() -> str:
    """Gets the default Databricks workspace hostname.

    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # Note: str.lstrip strips a set of characters rather than a prefix, so
    # host.lstrip("https://") would also eat leading "h", "t", "p", or "s"
    # characters from the hostname itself; strip the scheme with a regex instead.
    host = re.sub(r"^https?://", "", host).rstrip("/")
    return host


def get_default_api_token() -> str:
    """Gets the default Databricks personal access token.

    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token

def _is_hex_string(data: str) -> bool:
    """Checks whether data is a valid hexadecimal string, using a regular expression."""
    if not isinstance(data, str):
        return False
    pattern = r"^[0-9a-fA-F]+$"
    return bool(re.match(pattern, data))


def _load_pickled_fn_from_hex_string(data: str) -> Callable:
    """Loads a pickled function from a hexadecimal string."""
    try:
        return pickle.loads(bytes.fromhex(data))
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
        )


def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickles a function and returns the hexadecimal string."""
    try:
        return pickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}")

class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from
      the endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using the HTTP POST method
      with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen to
      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply necessary
    transformations before and after the query.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address or simply ``0.0.0.0`` to connect.
    We recommend the server use a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
    the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        if v:
            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
        return v

    def __init__(self, **data: Any):
        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                data["transform_input_fn"]
            )
        if "transform_output_fn" in data and _is_hex_string(
            data["transform_output_fn"]
        ):
            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
                data["transform_output_fn"]
            )

        super().__init__(**data)
        if self.model_kwargs is not None and self.extra_params is not None:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""

        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)