From aec642febb3daa7dbb6a19996aac2efa92bbf1bd Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 25 May 2023 19:19:37 -0700
Subject: [PATCH] LLM wrapper for Databricks (#5142)

This PR adds an LLM wrapper for Databricks. It supports two endpoint types:
* serving endpoint
* cluster driver proxy app

An integration notebook is included to show how it works.

Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Gengliang Wang
Co-authored-by: Dev 2049
---
 .../models/llms/integrations/databricks.ipynb | 523 ++++++++++++++++++
 langchain/llms/__init__.py                    |   3 +
 langchain/llms/databricks.py                  | 323 +++++++++++
 3 files changed, 849 insertions(+)
 create mode 100644 docs/modules/models/llms/integrations/databricks.ipynb
 create mode 100644 langchain/llms/databricks.py

diff --git a/docs/modules/models/llms/integrations/databricks.ipynb b/docs/modules/models/llms/integrations/databricks.ipynb
new file mode 100644
index 0000000000..68425cf40d
--- /dev/null
+++ b/docs/modules/models/llms/integrations/databricks.ipynb
@@ -0,0 +1,523 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "application/vnd.databricks.v1+cell": {
+     "cellMetadata": {},
+     "inputWidgets": {},
+     "nuid": "5147e458-3b83-449e-9c2f-e7e1972e43fc",
+     "showTitle": false,
+     "title": ""
+    }
+   },
+   "source": [
+    "# Databricks\n",
+    "\n",
+    "The [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n",
+    "\n",
+    "This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain.\n",
+    "It supports two endpoint types:\n",
+    "* Serving endpoint, recommended for production and development,\n",
+    "* Cluster driver proxy app, recommended for interactive development."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "application/vnd.databricks.v1+cell": {
+     "cellMetadata": {
+      "byteLimit": 2048000,
+      "rowLimit": 10000
+     },
+     "inputWidgets": {},
+     "nuid": "bf07455f-aac9-4873-a8e7-7952af0f8c82",
+     "showTitle": false,
+     "title": ""
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Databricks"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "application/vnd.databricks.v1+cell": {
+     "cellMetadata": {},
+     "inputWidgets": {},
+     "nuid": "94f6540e-40cd-4d9b-95d3-33d36f061dcc",
+     "showTitle": false,
+     "title": ""
+    }
+   },
+   "source": [
+    "## Wrapping a serving endpoint\n",
+    "\n",
+    "Prerequisites:\n",
+    "* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html).\n",
+    "* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
+    "\n",
+    "The expected MLflow model signature is:\n",
+    "  * inputs: `[{\"name\": \"prompt\", \"type\": \"string\"}, {\"name\": \"stop\", \"type\": \"list[string]\"}]`\n",
+    "  * outputs: `[{\"type\": \"string\"}]`\n",
+    "\n",
+    "If the model signature is incompatible or you want to insert extra configs, you can set `transform_input_fn` and `transform_output_fn` accordingly.\n",
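+    "\n",
+    "For example, a minimal sketch of the two hooks (the `use_cache` parameter and the `text` field below are illustrative assumptions about a particular endpoint, not part of the wrapper):\n",
+    "\n",
+    "```python\n",
+    "def transform_input(**request):\n",
+    "    # Insert an extra config that this endpoint happens to accept (illustrative).\n",
+    "    request[\"use_cache\"] = False\n",
+    "    return request\n",
+    "\n",
+    "def transform_output(response):\n",
+    "    # Extract the generated text from a custom response payload (illustrative).\n",
+    "    return response[\"text\"]\n",
+    "```"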
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7496dc7a-8a1a-4ce6-9648-4f69ed25275b", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I am happy to hear that you are in good health and as always, you are appreciated.'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# If running a Databricks notebook attached to an interactive cluster in \"single user\" \n", + "# or \"no isolation shared\" mode, you only need to specify the endpoint name to create \n", + "# a `Databricks` instance to query a serving endpoint in the same workspace.\n", + "llm = Databricks(endpoint_name=\"dolly\")\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0c86d952-4236-4a5e-bdac-cf4e3ccf3a16", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Good'" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm(\"How are you?\", stop=[\".\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5f2507a2-addd-431d-9da5-dc2ae33783f6", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I am fine. 
Thank you!'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Otherwise, you can manually specify the Databricks workspace hostname and personal access token \n", + "# or set `DATABRICKS_HOST` and `DATABRICKS_API_TOKEN` environment variables, respectively.\n", + "# See https://docs.databricks.com/dev-tools/auth.html#databricks-personal-access-tokens\n", + "# We strongly recommend not exposing the API token explicitly inside a notebook.\n", + "# You can use Databricks secret manager to store your API token securely.\n", + "# See https://docs.databricks.com/dev-tools/databricks-utils.html#secrets-utility-dbutilssecrets\n", + "\n", + "import os\n", + "os.environ[\"DATABRICKS_API_TOKEN\"] = dbutils.secrets.get(\"myworkspace\", \"api_token\")\n", + "\n", + "llm = Databricks(host=\"myworkspace.cloud.databricks.com\", endpoint_name=\"dolly\")\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9b54f8ce-ffe5-4c47-a3f0-b4ebde524a6a", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I am fine.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# If the serving endpoint accepts extra parameters like `temperature`,\n", + "# you can set them in `model_kwargs`.\n", + "llm = Databricks(endpoint_name=\"dolly\", model_kwargs={\"temperature\": 0.1})\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "50f172f5-ea1f-4ceb-8cf1-20289848de7b", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I’m Excellent. 
You?'" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use `transform_input_fn` and `transform_output_fn` if the serving endpoint\n", + "# expects a different input schema and does not return a JSON string,\n", + "# respectively, or you want to apply a prompt template on top.\n", + "\n", + "def transform_input(**request):\n", + " full_prompt = f\"\"\"{request[\"prompt\"]}\n", + " Be Concise.\n", + " \"\"\"\n", + " request[\"prompt\"] = full_prompt\n", + " return request\n", + "\n", + "llm = Databricks(endpoint_name=\"dolly\", transform_input_fn=transform_input)\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "8ea49319-a041-494d-afcd-87bcf00d5efb", + "showTitle": false, + "title": "" + } + }, + "source": [ + "## Wrapping a cluster driver proxy app\n", + "\n", + "Prerequisites:\n", + "* An LLM loaded on a Databricks interactive cluster in \"single user\" or \"no isolation shared\" mode.\n", + "* A local HTTP server running on the driver node to serve the model at `\"/\"` using HTTP POST with JSON input/output.\n", + "* It uses a port number between `[3000, 8000]` and litens to the driver IP address or simply `0.0.0.0` instead of localhost only.\n", + "* You have \"Can Attach To\" permission to the cluster.\n", + "\n", + "The expected server schema (using JSON schema) is:\n", + "* inputs:\n", + " ```json\n", + " {\"type\": \"object\",\n", + " \"properties\": {\n", + " \"prompt\": {\"type\": \"string\"},\n", + " \"stop\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}},\n", + " \"required\": [\"prompt\"]}\n", + " ```\n", + "* outputs: `{\"type\": \"string\"}`\n", + "\n", + "If the server schema is incompatible or you want to insert extra configs, you can use `transform_input_fn` and `transform_output_fn` accordingly.\n", + "\n", + "The following is a minimal example for running a driver proxy app to serve an LLM:\n", + "\n", + "```python\n", + "from flask import Flask, request, jsonify\n", + "import torch\n", + "from transformers import pipeline, AutoTokenizer, StoppingCriteria\n", + "\n", + "model = \"databricks/dolly-v2-3b\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model, padding_side=\"left\")\n", + "dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map=\"auto\")\n", + "device = dolly.device\n", + "\n", + "class CheckStop(StoppingCriteria):\n", + " def __init__(self, stop=None):\n", + " super().__init__()\n", + " self.stop = stop or []\n", + " self.matched = \"\"\n", + " self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]\n", + " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n", + " for i, s in enumerate(self.stop_ids):\n", + " if torch.all((s == input_ids[0][-s.shape[1]:])).item():\n", + " self.matched = self.stop[i]\n", + " return True\n", + " return False\n", + "\n", + "def llm(prompt, stop=None, **kwargs):\n", + " check_stop = CheckStop(stop)\n", + " result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)\n", + " return result[0][\"generated_text\"].rstrip(check_stop.matched)\n", + "\n", + "app = Flask(\"dolly\")\n", + "\n", + "@app.route('/', methods=['POST'])\n", + "def serve_llm():\n", + " resp = llm(**request.json)\n", + " return jsonify(resp)\n", + "\n", + "app.run(host=\"0.0.0.0\", port=\"7777\")\n", + "```\n", + "\n", + "Once the server is 
running, you can create a `Databricks` instance to wrap it as an LLM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "e3330a01-e738-4170-a176-9954aff56442", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Hello, thank you for asking. It is wonderful to hear that you are well.'" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# If running a Databricks notebook attached to the same cluster that runs the app,\n", + "# you only need to specify the driver port to create a `Databricks` instance.\n", + "llm = Databricks(cluster_driver_port=\"7777\")\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "39c121cf-0e44-4e31-91db-37fcac459677", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I am well. You?'" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Otherwise, you can manually specify the cluster ID to use,\n", + "# as well as Databricks workspace hostname and personal access token.\n", + "\n", + "llm = Databricks(cluster_id=\"0000-000000-xxxxxxxx\", cluster_driver_port=\"7777\")\n", + "\n", + "llm(\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3d3de599-82fd-45e4-8d8b-bacfc49dc9ce", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'I am very well. 
It is a pleasure to meet you.'"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# If the app accepts extra parameters like `temperature`,\n",
+    "# you can set them in `model_kwargs`.\n",
+    "llm = Databricks(cluster_driver_port=\"7777\", model_kwargs={\"temperature\": 0.1})\n",
+    "\n",
+    "llm(\"How are you?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "application/vnd.databricks.v1+cell": {
+     "cellMetadata": {
+      "byteLimit": 2048000,
+      "rowLimit": 10000
+     },
+     "inputWidgets": {},
+     "nuid": "853fae8e-8df4-41e6-9d45-7769f883fe80",
+     "showTitle": false,
+     "title": ""
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'I AM DOING GREAT THANK YOU.'"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Use `transform_input_fn` if the app expects a different input schema,\n",
+    "# and `transform_output_fn` if it does not return a JSON string.\n",
+    "# You can also use them to apply a prompt template on top of the input.\n",
+    "\n",
+    "def transform_input(**request):\n",
+    "    full_prompt = f\"\"\"{request[\"prompt\"]}\n",
+    "    Be Concise.\n",
+    "    \"\"\"\n",
+    "    request[\"prompt\"] = full_prompt\n",
+    "    return request\n",
+    "\n",
+    "def transform_output(response):\n",
+    "    return response.upper()\n",
+    "\n",
+    "llm = Databricks(\n",
+    "    cluster_driver_port=\"7777\",\n",
+    "    transform_input_fn=transform_input,\n",
+    "    transform_output_fn=transform_output)\n",
+    "\n",
+    "llm(\"How are you?\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "application/vnd.databricks.v1+notebook": {
+   "dashboards": [],
+   "language": "python",
+   "notebookMetadata": {
+    "pythonIndentUnit": 2
+   },
+   "notebookName": "databricks",
+   "widgets": {}
+  },
+  "kernelspec": {
+   "display_name": "llm",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}

diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 29e66ce947..3e62a59d5f 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -11,6 +11,7 @@ from langchain.llms.beam import Beam
 from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
 from langchain.llms.ctransformers import CTransformers
+from langchain.llms.databricks import Databricks
 from langchain.llms.deepinfra import DeepInfra
 from langchain.llms.fake import FakeListLLM
 from langchain.llms.forefrontai import ForefrontAI
@@ -50,6 +51,7 @@ __all__ = [
     "CerebriumAI",
     "Cohere",
     "CTransformers",
+    "Databricks",
     "DeepInfra",
     "ForefrontAI",
     "GooglePalm",
@@ -95,6 +97,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "cerebriumai": CerebriumAI,
     "cohere": Cohere,
     "ctransformers": CTransformers,
+    "databricks": Databricks,
     "deepinfra": DeepInfra,
     "forefrontai": ForefrontAI,
     "google_palm": GooglePalm,
diff --git a/langchain/llms/databricks.py b/langchain/llms/databricks.py
new file mode 100644
index 0000000000..d3ba3a1604
--- /dev/null
+++ b/langchain/llms/databricks.py
@@ -0,0 +1,323 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional
+
+import requests
+from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator, validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+__all__ = ["Databricks"]
+
+
+class _DatabricksClientBase(BaseModel, ABC):
+    """A base JSON API client that talks to Databricks."""
+
+    api_url: str
+    api_token: str
+
+    def post_raw(self, request: Any) -> Any:
+        headers = {"Authorization": f"Bearer {self.api_token}"}
+        response = requests.post(self.api_url, headers=headers, json=request)
+        # TODO: error handling and automatic retries
+        if not response.ok:
+            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+        return response.json()
+
+    @abstractmethod
+    def post(self, request: Any) -> Any:
+        ...
+
+
+class _DatabricksServingEndpointClient(_DatabricksClientBase):
+    """An API client that talks to a Databricks serving endpoint."""
+
+    host: str
+    endpoint_name: str
+
+    @root_validator(pre=True)
+    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "api_url" not in values:
+            host = values["host"]
+            endpoint_name = values["endpoint_name"]
+            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
+            values["api_url"] = api_url
+        return values
+
+    def post(self, request: Any) -> Any:
+        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
+        wrapped_request = {"dataframe_records": [request]}
+        response = self.post_raw(wrapped_request)["predictions"]
+        # For a single-record query, the result is not a list.
+        if isinstance(response, list):
+            response = response[0]
+        return response
+
+
+class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
+    """An API client that talks to a Databricks cluster driver proxy app."""
+
+    host: str
+    cluster_id: str
+    cluster_driver_port: str
+
+    @root_validator(pre=True)
+    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "api_url" not in values:
+            host = values["host"]
+            cluster_id = values["cluster_id"]
+            port = values["cluster_driver_port"]
+            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
+            values["api_url"] = api_url
+        return values
+
+    def post(self, request: Any) -> Any:
+        return self.post_raw(request)
+
+
+def get_repl_context() -> Any:
+    """Gets the notebook REPL context if running inside a Databricks notebook.
+    Raises an error otherwise.
+    """
+    try:
+        from dbruntime.databricks_repl_context import get_context
+
+        return get_context()
+    except ImportError:
+        raise ValueError(
+            "Cannot access dbruntime, not running inside a Databricks notebook."
+        )
+
+
+def get_default_host() -> str:
+    """Gets the default Databricks workspace hostname.
+    Raises an error if the hostname cannot be automatically determined.
+    """
+    host = os.getenv("DATABRICKS_HOST")
+    if not host:
+        try:
+            host = get_repl_context().browserHostName
+            if not host:
+                raise ValueError("context doesn't contain browserHostName.")
+        except Exception as e:
+            raise ValueError(
+                "host was not set and cannot be automatically inferred. Set "
+                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
+            )
+    # TODO: support Databricks CLI profile
+    # Strip an optional scheme prefix and trailing slash. (`str.lstrip` would be
+    # wrong here: it strips a character set, not a prefix.)
+    for prefix in ("https://", "http://"):
+        if host.startswith(prefix):
+            host = host[len(prefix) :]
+    host = host.rstrip("/")
+    return host
+
+
+def get_default_api_token() -> str:
+    """Gets the default Databricks personal access token.
+    Raises an error if the token cannot be automatically determined.
+    """
+    if api_token := os.getenv("DATABRICKS_API_TOKEN"):
+        return api_token
+    try:
+        api_token = get_repl_context().apiToken
+        if not api_token:
+            raise ValueError("context doesn't contain apiToken.")
+    except Exception as e:
+        raise ValueError(
+            "api_token was not set and cannot be automatically inferred. Set "
+            f"environment variable 'DATABRICKS_API_TOKEN'. Received error: {e}"
+        )
+    # TODO: support Databricks CLI profile
+    return api_token
+
+
+class Databricks(LLM):
+    """LLM wrapper around a Databricks serving endpoint or a cluster driver proxy app.
+    It supports two endpoint types:
+
+    * **Serving endpoint** (recommended for both production and development).
+      We assume that an LLM was registered and deployed to a serving endpoint.
+      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
+      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
+      ``cluster_driver_port``.
+      The expected model signature is:
+
+      * inputs::
+
+          [{"name": "prompt", "type": "string"},
+           {"name": "stop", "type": "list[string]"}]
+
+      * outputs: ``[{"type": "string"}]``
+
+    * **Cluster driver proxy app** (recommended for interactive development).
+      One can load an LLM on a Databricks interactive cluster and start a local HTTP
+      server on the driver node to serve the model at ``/`` using the HTTP POST method
+      with JSON input/output.
+      Please use a port number between ``[3000, 8000]`` and let the server listen on
+      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
+      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
+      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
+      The expected server schema (using JSON schema) is:
+
+      * inputs::
+
+          {"type": "object",
+           "properties": {
+             "prompt": {"type": "string"},
+             "stop": {"type": "array", "items": {"type": "string"}}},
+           "required": ["prompt"]}
+
+      * outputs: ``{"type": "string"}``
+
+    If the endpoint model signature is different or you want to set extra params,
+    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply necessary
+    transformations before and after the query.
+    """
+
+    host: str = Field(default_factory=get_default_host)
+    """Databricks workspace hostname.
+    If not provided, the default value is determined by
+
+    * the ``DATABRICKS_HOST`` environment variable if present, or
+    * the hostname of the current Databricks workspace if running inside
+      a Databricks notebook attached to an interactive cluster in "single user"
+      or "no isolation shared" mode.
+    """
+
+    api_token: str = Field(default_factory=get_default_api_token)
+    """Databricks personal access token.
+    If not provided, the default value is determined by
+
+    * the ``DATABRICKS_API_TOKEN`` environment variable if present, or
+    * an automatically generated temporary token if running inside a Databricks
+      notebook attached to an interactive cluster in "single user" or
+      "no isolation shared" mode.
+    """
+
+    endpoint_name: Optional[str] = None
+    """Name of the model serving endpoint.
+    You must specify the endpoint name to connect to a model serving endpoint.
+    You must not set both ``endpoint_name`` and ``cluster_id``.
+    """
+
+    cluster_id: Optional[str] = None
+    """ID of the cluster if connecting to a cluster driver proxy app.
+    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
+    inside a Databricks notebook attached to an interactive cluster in "single user"
+    or "no isolation shared" mode, the current cluster ID is used as default.
+    You must not set both ``endpoint_name`` and ``cluster_id``.
+    """
+
+    cluster_driver_port: Optional[str] = None
+    """The port number used by the HTTP server running on the cluster driver node.
+    The server should listen on the driver IP address or simply ``0.0.0.0`` so that
+    it is reachable beyond localhost.
+    We recommend using a port number between ``[3000, 8000]``.
+    """
+
+    model_kwargs: Optional[Dict[str, Any]] = None
+    """Extra parameters to pass to the endpoint."""
+
+    transform_input_fn: Optional[Callable] = None
+    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
+    request object that the endpoint accepts.
+    For example, you can apply a prompt template to the input prompt.
+    """
+
+    transform_output_fn: Optional[Callable[..., str]] = None
+    """A function that transforms the output from the endpoint to the generated text.
+    """
+
+    _client: _DatabricksClientBase = PrivateAttr()
+
+    class Config:
+        extra = Extra.forbid
+        underscore_attrs_are_private = True
+
+    @validator("cluster_id", always=True)
+    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
+        if v and values["endpoint_name"]:
+            raise ValueError("Cannot set both endpoint_name and cluster_id.")
+        elif values["endpoint_name"]:
+            return None
+        elif v:
+            return v
+        else:
+            try:
+                if v := get_repl_context().clusterId:
+                    return v
+                raise ValueError("Context doesn't contain clusterId.")
+            except Exception as e:
+                raise ValueError(
+                    "Neither endpoint_name nor cluster_id was set, and cluster_id "
+                    "cannot be automatically determined. Received"
+                    f" error: {e}"
+                )
+
+    @validator("cluster_driver_port", always=True)
+    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
+        if v and values["endpoint_name"]:
+            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
+        elif values["endpoint_name"]:
+            return None
+        elif v is None:
+            raise ValueError(
+                "Must set cluster_driver_port to connect to a cluster driver."
+            )
+        elif int(v) <= 0:
+            raise ValueError(f"Invalid cluster_driver_port: {v}")
+        else:
+            return v
+
+    @validator("model_kwargs", always=True)
+    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+        if v:
+            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
+            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
+        return v
+
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+        if self.endpoint_name:
+            self._client = _DatabricksServingEndpointClient(
+                host=self.host,
+                api_token=self.api_token,
+                endpoint_name=self.endpoint_name,
+            )
+        elif self.cluster_id and self.cluster_driver_port:
+            self._client = _DatabricksClusterDriverProxyClient(
+                host=self.host,
+                api_token=self.api_token,
+                cluster_id=self.cluster_id,
+                cluster_driver_port=self.cluster_driver_port,
+            )
+        else:
+            raise ValueError(
+                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
+ ) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "databricks" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: + """Queries the LLM endpoint with the given prompt and stop sequence.""" + + # TODO: support callbacks + + request = {"prompt": prompt, "stop": stop} + if self.model_kwargs: + request.update(self.model_kwargs) + + if self.transform_input_fn: + request = self.transform_input_fn(**request) + + response = self._client.post(request) + + if self.transform_output_fn: + response = self.transform_output_fn(response) + + return response