mirror of https://github.com/hwchase17/langchain
LLM wrapper for Databricks (#5142)
This PR adds an LLM wrapper for Databricks. It supports two endpoint types:

* serving endpoint
* cluster driver proxy app

An integration notebook is included to show how it works.

Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Gengliang Wang <gengliang@apache.org>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
parent 1cb6498fdb
commit aec642febb
@@ -0,0 +1,323 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

import requests
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator, validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

__all__ = ["Databricks"]


class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def post_raw(self, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.post(self.api_url, headers=headers, json=request)
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    @abstractmethod
    def post(self, request: Any) -> Any:
        ...


class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
        wrapped_request = {"dataframe_records": [request]}
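        # (illustrative) e.g. {"dataframe_records": [{"prompt": "Hello", "stop": None}]}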
        response = self.post_raw(wrapped_request)["predictions"]
        # For a single-record query, the result is not a list.
        if isinstance(response, list):
            response = response[0]
        return response


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        return self.post_raw(request)


def get_repl_context() -> Any:
    """Gets the notebook REPL context if running inside a Databricks notebook.
    Raises an error if the ``dbruntime`` module is not available, i.e., when not
    running inside a Databricks notebook.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ValueError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )


def get_default_host() -> str:
    """Gets the default Databricks workspace hostname.
    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # Strip an optional URL scheme and any trailing slash. Note that str.lstrip
    # removes a set of characters, not a prefix, so it cannot be used here.
    for prefix in ("https://", "http://"):
        if host.startswith(prefix):
            host = host[len(prefix):]
    host = host.rstrip("/")
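    # (illustrative) "https://my-workspace.cloud.databricks.com/" becomes
    # "my-workspace.cloud.databricks.com"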
    return host


def get_default_api_token() -> str:
    """Gets the default Databricks personal access token.
    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_API_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_API_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token


class Databricks(LLM):
    """LLM wrapper around a Databricks serving endpoint or a cluster driver proxy app.

    It supports two endpoint types (usage examples follow the list below):

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was registered and deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.
      The expected model signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local
      HTTP server on the driver node to serve the model at ``/`` using the HTTP
      POST method with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen
      to the driver IP address, or simply ``0.0.0.0``, instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply the
    necessary transformations before and after the query, as illustrated below.
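
    For example, if the endpoint expects ``{"inputs": "<prompt>"}`` (a made-up
    schema) and returns ``{"candidates": ["<text>"]}``, you could write::

        def transform_input(**request):
            # `request` carries prompt, stop, and any model_kwargs.
            return {"inputs": request["prompt"]}

        def transform_output(response):
            return response["candidates"][0]

        llm = Databricks(endpoint_name="dolly",
                         transform_input_fn=transform_input,
                         transform_output_fn=transform_output)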
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_API_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as the default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address, or simply ``0.0.0.0``, to
    accept connections.
    We recommend that the server use a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """Extra parameters to pass to the endpoint."""

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated
    text.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

@validator("cluster_id", always=True)
|
||||
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v:
|
||||
return v
|
||||
else:
|
||||
try:
|
||||
if v := get_repl_context().clusterId:
|
||||
return v
|
||||
raise ValueError("Context doesn't contain clusterId.")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Neither endpoint_name nor cluster_id was set. "
|
||||
"And the cluster_id cannot be automatically determined. Received"
|
||||
f" error: {e}"
|
||||
)
|
||||
|
||||
@validator("cluster_driver_port", always=True)
|
||||
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v is None:
|
||||
raise ValueError(
|
||||
"Must set cluster_driver_port to connect to a cluster driver."
|
||||
)
|
||||
elif int(v) <= 0:
|
||||
raise ValueError(f"Invalid cluster_driver_port: {v}")
|
||||
else:
|
||||
return v
|
||||
|
||||
@validator("model_kwargs", always=True)
|
||||
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if v:
|
||||
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
|
||||
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
|
||||
return v
|
||||
|
||||
    def __init__(self, **data: Any):
        super().__init__(**data)
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""

        # TODO: support callbacks

        request = {"prompt": prompt, "stop": stop}
        if self.model_kwargs:
            request.update(self.model_kwargs)
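            # (illustrative) with model_kwargs={"temperature": 0.1} the request
            # becomes {"prompt": ..., "stop": ..., "temperature": 0.1}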

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        response = self._client.post(request)

        if self.transform_output_fn:
            response = self.transform_output_fn(response)

        return response