mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
LLM wrapper for Databricks (#5142)
This PR adds LLM wrapper for Databricks. It supports two endpoint types: * serving endpoint * cluster driver proxy app An integration notebook is included to show how it works. Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Gengliang Wang <gengliang@apache.org> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
1cb6498fdb
commit
aec642febb
523
docs/modules/models/llms/integrations/databricks.ipynb
Normal file
523
docs/modules/models/llms/integrations/databricks.ipynb
Normal file
@ -0,0 +1,523 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "5147e458-3b83-449e-9c2f-e7e1972e43fc",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Databricks\n",
|
||||||
|
"\n",
|
||||||
|
"The [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n",
|
||||||
|
"\n",
|
||||||
|
"This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain.\n",
|
||||||
|
"It supports two endpoint types:\n",
|
||||||
|
"* Serving endpoint, recommended for production and development,\n",
|
||||||
|
"* Cluster driver proxy app, recommended for iteractive development."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "bf07455f-aac9-4873-a8e7-7952af0f8c82",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import Databricks"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "94f6540e-40cd-4d9b-95d3-33d36f061dcc",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Wrapping a serving endpoint\n",
|
||||||
|
"\n",
|
||||||
|
"Prerequisites:\n",
|
||||||
|
"* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html).\n",
|
||||||
|
"* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
|
||||||
|
"\n",
|
||||||
|
"The expected MLflow model signature is:\n",
|
||||||
|
" * inputs: `[{\"name\": \"prompt\", \"type\": \"string\"}, {\"name\": \"stop\", \"type\": \"list[string]\"}]`\n",
|
||||||
|
" * outputs: `[{\"type\": \"string\"}]`\n",
|
||||||
|
"\n",
|
||||||
|
"If the model signature is incompatible or you want to insert extra configs, you can set `transform_input_fn` and `transform_output_fn` accordingly."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "7496dc7a-8a1a-4ce6-9648-4f69ed25275b",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I am happy to hear that you are in good health and as always, you are appreciated.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# If running a Databricks notebook attached to an interactive cluster in \"single user\" \n",
|
||||||
|
"# or \"no isolation shared\" mode, you only need to specify the endpoint name to create \n",
|
||||||
|
"# a `Databricks` instance to query a serving endpoint in the same workspace.\n",
|
||||||
|
"llm = Databricks(endpoint_name=\"dolly\")\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "0c86d952-4236-4a5e-bdac-cf4e3ccf3a16",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Good'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 34,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"llm(\"How are you?\", stop=[\".\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "5f2507a2-addd-431d-9da5-dc2ae33783f6",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I am fine. Thank you!'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Otherwise, you can manually specify the Databricks workspace hostname and personal access token \n",
|
||||||
|
"# or set `DATABRICKS_HOST` and `DATABRICKS_API_TOKEN` environment variables, respectively.\n",
|
||||||
|
"# See https://docs.databricks.com/dev-tools/auth.html#databricks-personal-access-tokens\n",
|
||||||
|
"# We strongly recommend not exposing the API token explicitly inside a notebook.\n",
|
||||||
|
"# You can use Databricks secret manager to store your API token securely.\n",
|
||||||
|
"# See https://docs.databricks.com/dev-tools/databricks-utils.html#secrets-utility-dbutilssecrets\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"os.environ[\"DATABRICKS_API_TOKEN\"] = dbutils.secrets.get(\"myworkspace\", \"api_token\")\n",
|
||||||
|
"\n",
|
||||||
|
"llm = Databricks(host=\"myworkspace.cloud.databricks.com\", endpoint_name=\"dolly\")\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "9b54f8ce-ffe5-4c47-a3f0-b4ebde524a6a",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I am fine.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# If the serving endpoint accepts extra parameters like `temperature`,\n",
|
||||||
|
"# you can set them in `model_kwargs`.\n",
|
||||||
|
"llm = Databricks(endpoint_name=\"dolly\", model_kwargs={\"temperature\": 0.1})\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "50f172f5-ea1f-4ceb-8cf1-20289848de7b",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I’m Excellent. You?'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Use `transform_input_fn` and `transform_output_fn` if the serving endpoint\n",
|
||||||
|
"# expects a different input schema and does not return a JSON string,\n",
|
||||||
|
"# respectively, or you want to apply a prompt template on top.\n",
|
||||||
|
"\n",
|
||||||
|
"def transform_input(**request):\n",
|
||||||
|
" full_prompt = f\"\"\"{request[\"prompt\"]}\n",
|
||||||
|
" Be Concise.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" request[\"prompt\"] = full_prompt\n",
|
||||||
|
" return request\n",
|
||||||
|
"\n",
|
||||||
|
"llm = Databricks(endpoint_name=\"dolly\", transform_input_fn=transform_input)\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "8ea49319-a041-494d-afcd-87bcf00d5efb",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Wrapping a cluster driver proxy app\n",
|
||||||
|
"\n",
|
||||||
|
"Prerequisites:\n",
|
||||||
|
"* An LLM loaded on a Databricks interactive cluster in \"single user\" or \"no isolation shared\" mode.\n",
|
||||||
|
"* A local HTTP server running on the driver node to serve the model at `\"/\"` using HTTP POST with JSON input/output.\n",
|
||||||
|
"* It uses a port number between `[3000, 8000]` and litens to the driver IP address or simply `0.0.0.0` instead of localhost only.\n",
|
||||||
|
"* You have \"Can Attach To\" permission to the cluster.\n",
|
||||||
|
"\n",
|
||||||
|
"The expected server schema (using JSON schema) is:\n",
|
||||||
|
"* inputs:\n",
|
||||||
|
" ```json\n",
|
||||||
|
" {\"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"prompt\": {\"type\": \"string\"},\n",
|
||||||
|
" \"stop\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}},\n",
|
||||||
|
" \"required\": [\"prompt\"]}\n",
|
||||||
|
" ```\n",
|
||||||
|
"* outputs: `{\"type\": \"string\"}`\n",
|
||||||
|
"\n",
|
||||||
|
"If the server schema is incompatible or you want to insert extra configs, you can use `transform_input_fn` and `transform_output_fn` accordingly.\n",
|
||||||
|
"\n",
|
||||||
|
"The following is a minimal example for running a driver proxy app to serve an LLM:\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"from flask import Flask, request, jsonify\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from transformers import pipeline, AutoTokenizer, StoppingCriteria\n",
|
||||||
|
"\n",
|
||||||
|
"model = \"databricks/dolly-v2-3b\"\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(model, padding_side=\"left\")\n",
|
||||||
|
"dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map=\"auto\")\n",
|
||||||
|
"device = dolly.device\n",
|
||||||
|
"\n",
|
||||||
|
"class CheckStop(StoppingCriteria):\n",
|
||||||
|
" def __init__(self, stop=None):\n",
|
||||||
|
" super().__init__()\n",
|
||||||
|
" self.stop = stop or []\n",
|
||||||
|
" self.matched = \"\"\n",
|
||||||
|
" self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]\n",
|
||||||
|
" def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
|
||||||
|
" for i, s in enumerate(self.stop_ids):\n",
|
||||||
|
" if torch.all((s == input_ids[0][-s.shape[1]:])).item():\n",
|
||||||
|
" self.matched = self.stop[i]\n",
|
||||||
|
" return True\n",
|
||||||
|
" return False\n",
|
||||||
|
"\n",
|
||||||
|
"def llm(prompt, stop=None, **kwargs):\n",
|
||||||
|
" check_stop = CheckStop(stop)\n",
|
||||||
|
" result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)\n",
|
||||||
|
" return result[0][\"generated_text\"].rstrip(check_stop.matched)\n",
|
||||||
|
"\n",
|
||||||
|
"app = Flask(\"dolly\")\n",
|
||||||
|
"\n",
|
||||||
|
"@app.route('/', methods=['POST'])\n",
|
||||||
|
"def serve_llm():\n",
|
||||||
|
" resp = llm(**request.json)\n",
|
||||||
|
" return jsonify(resp)\n",
|
||||||
|
"\n",
|
||||||
|
"app.run(host=\"0.0.0.0\", port=\"7777\")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Once the server is running, you can create a `Databricks` instance to wrap it as an LLM."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "e3330a01-e738-4170-a176-9954aff56442",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Hello, thank you for asking. It is wonderful to hear that you are well.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 32,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# If running a Databricks notebook attached to the same cluster that runs the app,\n",
|
||||||
|
"# you only need to specify the driver port to create a `Databricks` instance.\n",
|
||||||
|
"llm = Databricks(cluster_driver_port=\"7777\")\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "39c121cf-0e44-4e31-91db-37fcac459677",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I am well. You?'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 40,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Otherwise, you can manually specify the cluster ID to use,\n",
|
||||||
|
"# as well as Databricks workspace hostname and personal access token.\n",
|
||||||
|
"\n",
|
||||||
|
"llm = Databricks(cluster_id=\"0000-000000-xxxxxxxx\", cluster_driver_port=\"7777\")\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "3d3de599-82fd-45e4-8d8b-bacfc49dc9ce",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I am very well. It is a pleasure to meet you.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 31,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# If the app accepts extra parameters like `temperature`,\n",
|
||||||
|
"# you can set them in `model_kwargs`.\n",
|
||||||
|
"llm = Databricks(cluster_driver_port=\"7777\", model_kwargs={\"temperature\": 0.1})\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+cell": {
|
||||||
|
"cellMetadata": {
|
||||||
|
"byteLimit": 2048000,
|
||||||
|
"rowLimit": 10000
|
||||||
|
},
|
||||||
|
"inputWidgets": {},
|
||||||
|
"nuid": "853fae8e-8df4-41e6-9d45-7769f883fe80",
|
||||||
|
"showTitle": false,
|
||||||
|
"title": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'I AM DOING GREAT THANK YOU.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 32,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Use `transform_input_fn` and `transform_output_fn` if the app\n",
|
||||||
|
"# expects a different input schema and does not return a JSON string,\n",
|
||||||
|
"# respectively, or you want to apply a prompt template on top.\n",
|
||||||
|
"\n",
|
||||||
|
"def transform_input(**request):\n",
|
||||||
|
" full_prompt = f\"\"\"{request[\"prompt\"]}\n",
|
||||||
|
" Be Concise.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" request[\"prompt\"] = full_prompt\n",
|
||||||
|
" return request\n",
|
||||||
|
"\n",
|
||||||
|
"def transform_output(response):\n",
|
||||||
|
" return response.upper()\n",
|
||||||
|
"\n",
|
||||||
|
"llm = Databricks(\n",
|
||||||
|
" cluster_driver_port=\"7777\",\n",
|
||||||
|
" transform_input_fn=transform_input,\n",
|
||||||
|
" transform_output_fn=transform_output)\n",
|
||||||
|
"\n",
|
||||||
|
"llm(\"How are you?\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"application/vnd.databricks.v1+notebook": {
|
||||||
|
"dashboards": [],
|
||||||
|
"language": "python",
|
||||||
|
"notebookMetadata": {
|
||||||
|
"pythonIndentUnit": 2
|
||||||
|
},
|
||||||
|
"notebookName": "databricks",
|
||||||
|
"widgets": {}
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "llm",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.10"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
@ -11,6 +11,7 @@ from langchain.llms.beam import Beam
|
|||||||
from langchain.llms.cerebriumai import CerebriumAI
|
from langchain.llms.cerebriumai import CerebriumAI
|
||||||
from langchain.llms.cohere import Cohere
|
from langchain.llms.cohere import Cohere
|
||||||
from langchain.llms.ctransformers import CTransformers
|
from langchain.llms.ctransformers import CTransformers
|
||||||
|
from langchain.llms.databricks import Databricks
|
||||||
from langchain.llms.deepinfra import DeepInfra
|
from langchain.llms.deepinfra import DeepInfra
|
||||||
from langchain.llms.fake import FakeListLLM
|
from langchain.llms.fake import FakeListLLM
|
||||||
from langchain.llms.forefrontai import ForefrontAI
|
from langchain.llms.forefrontai import ForefrontAI
|
||||||
@ -50,6 +51,7 @@ __all__ = [
|
|||||||
"CerebriumAI",
|
"CerebriumAI",
|
||||||
"Cohere",
|
"Cohere",
|
||||||
"CTransformers",
|
"CTransformers",
|
||||||
|
"Databricks",
|
||||||
"DeepInfra",
|
"DeepInfra",
|
||||||
"ForefrontAI",
|
"ForefrontAI",
|
||||||
"GooglePalm",
|
"GooglePalm",
|
||||||
@ -95,6 +97,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
|||||||
"cerebriumai": CerebriumAI,
|
"cerebriumai": CerebriumAI,
|
||||||
"cohere": Cohere,
|
"cohere": Cohere,
|
||||||
"ctransformers": CTransformers,
|
"ctransformers": CTransformers,
|
||||||
|
"databricks": Databricks,
|
||||||
"deepinfra": DeepInfra,
|
"deepinfra": DeepInfra,
|
||||||
"forefrontai": ForefrontAI,
|
"forefrontai": ForefrontAI,
|
||||||
"google_palm": GooglePalm,
|
"google_palm": GooglePalm,
|
||||||
|
323
langchain/llms/databricks.py
Normal file
323
langchain/llms/databricks.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator, validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
__all__ = ["Databricks"]
|
||||||
|
|
||||||
|
|
||||||
|
class _DatabricksClientBase(BaseModel, ABC):
|
||||||
|
"""A base JSON API client that talks to Databricks."""
|
||||||
|
|
||||||
|
api_url: str
|
||||||
|
api_token: str
|
||||||
|
|
||||||
|
def post_raw(self, request: Any) -> Any:
|
||||||
|
headers = {"Authorization": f"Bearer {self.api_token}"}
|
||||||
|
response = requests.post(self.api_url, headers=headers, json=request)
|
||||||
|
# TODO: error handling and automatic retries
|
||||||
|
if not response.ok:
|
||||||
|
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def post(self, request: Any) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _DatabricksServingEndpointClient(_DatabricksClientBase):
|
||||||
|
"""An API client that talks to a Databricks serving endpoint."""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
endpoint_name: str
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
if "api_url" not in values:
|
||||||
|
host = values["host"]
|
||||||
|
endpoint_name = values["endpoint_name"]
|
||||||
|
api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
|
||||||
|
values["api_url"] = api_url
|
||||||
|
return values
|
||||||
|
|
||||||
|
def post(self, request: Any) -> Any:
|
||||||
|
# See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
|
||||||
|
wrapped_request = {"dataframe_records": [request]}
|
||||||
|
response = self.post_raw(wrapped_request)["predictions"]
|
||||||
|
# For a single-record query, the result is not a list.
|
||||||
|
if isinstance(response, list):
|
||||||
|
response = response[0]
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
|
||||||
|
"""An API client that talks to a Databricks cluster driver proxy app."""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
cluster_id: str
|
||||||
|
cluster_driver_port: str
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
if "api_url" not in values:
|
||||||
|
host = values["host"]
|
||||||
|
cluster_id = values["cluster_id"]
|
||||||
|
port = values["cluster_driver_port"]
|
||||||
|
api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
|
||||||
|
values["api_url"] = api_url
|
||||||
|
return values
|
||||||
|
|
||||||
|
def post(self, request: Any) -> Any:
|
||||||
|
return self.post_raw(request)
|
||||||
|
|
||||||
|
|
||||||
|
def get_repl_context() -> Any:
|
||||||
|
"""Gets the notebook REPL context if running inside a Databricks notebook.
|
||||||
|
Returns None otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from dbruntime.databricks_repl_context import get_context
|
||||||
|
|
||||||
|
return get_context()
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot access dbruntime, not running inside a Databricks notebook."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_host() -> str:
|
||||||
|
"""Gets the default Databricks workspace hostname.
|
||||||
|
Raises an error if the hostname cannot be automatically determined.
|
||||||
|
"""
|
||||||
|
host = os.getenv("DATABRICKS_HOST")
|
||||||
|
if not host:
|
||||||
|
try:
|
||||||
|
host = get_repl_context().browserHostName
|
||||||
|
if not host:
|
||||||
|
raise ValueError("context doesn't contain browserHostName.")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
"host was not set and cannot be automatically inferred. Set "
|
||||||
|
f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
|
||||||
|
)
|
||||||
|
# TODO: support Databricks CLI profile
|
||||||
|
host = host.lstrip("https://").lstrip("http://").rstrip("/")
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_api_token() -> str:
|
||||||
|
"""Gets the default Databricks personal access token.
|
||||||
|
Raises an error if the token cannot be automatically determined.
|
||||||
|
"""
|
||||||
|
if api_token := os.getenv("DATABRICKS_API_TOKEN"):
|
||||||
|
return api_token
|
||||||
|
try:
|
||||||
|
api_token = get_repl_context().apiToken
|
||||||
|
if not api_token:
|
||||||
|
raise ValueError("context doesn't contain apiToken.")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
"api_token was not set and cannot be automatically inferred. Set "
|
||||||
|
f"environment variable 'DATABRICKS_API_TOKEN'. Received error: {e}"
|
||||||
|
)
|
||||||
|
# TODO: support Databricks CLI profile
|
||||||
|
return api_token
|
||||||
|
|
||||||
|
|
||||||
|
class Databricks(LLM):
|
||||||
|
"""LLM wrapper around a Databricks serving endpoint or a cluster driver proxy app.
|
||||||
|
It supports two endpoint types:
|
||||||
|
|
||||||
|
* **Serving endpoint** (recommended for both production and development).
|
||||||
|
We assume that an LLM was registered and deployed to a serving endpoint.
|
||||||
|
To wrap it as an LLM you must have "Can Query" permission to the endpoint.
|
||||||
|
Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
|
||||||
|
``cluster_driver_port``.
|
||||||
|
The expected model signature is:
|
||||||
|
|
||||||
|
* inputs::
|
||||||
|
|
||||||
|
[{"name": "prompt", "type": "string"},
|
||||||
|
{"name": "stop", "type": "list[string]"}]
|
||||||
|
|
||||||
|
* outputs: ``[{"type": "string"}]``
|
||||||
|
|
||||||
|
* **Cluster driver proxy app** (recommended for interactive development).
|
||||||
|
One can load an LLM on a Databricks interactive cluster and start a local HTTP
|
||||||
|
server on the driver node to serve the model at ``/`` using HTTP POST method
|
||||||
|
with JSON input/output.
|
||||||
|
Please use a port number between ``[3000, 8000]`` and let the server listen to
|
||||||
|
the driver IP address or simply ``0.0.0.0`` instead of localhost only.
|
||||||
|
To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
|
||||||
|
Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
|
||||||
|
The expected server schema (using JSON schema) is:
|
||||||
|
|
||||||
|
* inputs::
|
||||||
|
|
||||||
|
{"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {"type": "string"},
|
||||||
|
"stop": {"type": "array", "items": {"type": "string"}}},
|
||||||
|
"required": ["prompt"]}`
|
||||||
|
|
||||||
|
* outputs: ``{"type": "string"}``
|
||||||
|
|
||||||
|
If the endpoint model signature is different or you want to set extra params,
|
||||||
|
you can use `transform_input_fn` and `transform_output_fn` to apply necessary
|
||||||
|
transformations before and after the query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
host: str = Field(default_factory=get_default_host)
|
||||||
|
"""Databricks workspace hostname.
|
||||||
|
If not provided, the default value is determined by
|
||||||
|
|
||||||
|
* the ``DATABRICKS_HOST`` environment variable if present, or
|
||||||
|
* the hostname of the current Databricks workspace if running inside
|
||||||
|
a Databricks notebook attached to an interactive cluster in "single user"
|
||||||
|
or "no isolation shared" mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_token: str = Field(default_factory=get_default_api_token)
|
||||||
|
"""Databricks personal access token.
|
||||||
|
If not provided, the default value is determined by
|
||||||
|
|
||||||
|
* the ``DATABRICKS_API_TOKEN`` environment variable if present, or
|
||||||
|
* an automatically generated temporary token if running inside a Databricks
|
||||||
|
notebook attached to an interactive cluster in "single user" or
|
||||||
|
"no isolation shared" mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
endpoint_name: Optional[str] = None
|
||||||
|
"""Name of the model serving endpont.
|
||||||
|
You must specify the endpoint name to connect to a model serving endpoint.
|
||||||
|
You must not set both ``endpoint_name`` and ``cluster_id``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cluster_id: Optional[str] = None
|
||||||
|
"""ID of the cluster if connecting to a cluster driver proxy app.
|
||||||
|
If neither ``endpoint_name`` nor ``cluster_id`` is not provided and the code runs
|
||||||
|
inside a Databricks notebook attached to an interactive cluster in "single user"
|
||||||
|
or "no isolation shared" mode, the current cluster ID is used as default.
|
||||||
|
You must not set both ``endpoint_name`` and ``cluster_id``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cluster_driver_port: Optional[str] = None
|
||||||
|
"""The port number used by the HTTP server running on the cluster driver node.
|
||||||
|
The server should listen on the driver IP address or simply ``0.0.0.0`` to connect.
|
||||||
|
We recommend the server using a port number between ``[3000, 8000]``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
"""Extra parameters to pass to the endpoint."""
|
||||||
|
|
||||||
|
transform_input_fn: Optional[Callable] = None
|
||||||
|
"""A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
|
||||||
|
request object that the endpoint accepts.
|
||||||
|
For example, you can apply a prompt template to the input prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
transform_output_fn: Optional[Callable[..., str]] = None
|
||||||
|
"""A function that transforms the output from the endpoint to the generated text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_client: _DatabricksClientBase = PrivateAttr()
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = Extra.forbid
|
||||||
|
underscore_attrs_are_private = True
|
||||||
|
|
||||||
|
@validator("cluster_id", always=True)
|
||||||
|
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||||
|
if v and values["endpoint_name"]:
|
||||||
|
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
||||||
|
elif values["endpoint_name"]:
|
||||||
|
return None
|
||||||
|
elif v:
|
||||||
|
return v
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
if v := get_repl_context().clusterId:
|
||||||
|
return v
|
||||||
|
raise ValueError("Context doesn't contain clusterId.")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
"Neither endpoint_name nor cluster_id was set. "
|
||||||
|
"And the cluster_id cannot be automatically determined. Received"
|
||||||
|
f" error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("cluster_driver_port", always=True)
|
||||||
|
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||||
|
if v and values["endpoint_name"]:
|
||||||
|
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
||||||
|
elif values["endpoint_name"]:
|
||||||
|
return None
|
||||||
|
elif v is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Must set cluster_driver_port to connect to a cluster driver."
|
||||||
|
)
|
||||||
|
elif int(v) <= 0:
|
||||||
|
raise ValueError(f"Invalid cluster_driver_port: {v}")
|
||||||
|
else:
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("model_kwargs", always=True)
|
||||||
|
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||||
|
if v:
|
||||||
|
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
|
||||||
|
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
|
||||||
|
return v
|
||||||
|
|
||||||
|
def __init__(self, **data: Any):
|
||||||
|
super().__init__(**data)
|
||||||
|
if self.endpoint_name:
|
||||||
|
self._client = _DatabricksServingEndpointClient(
|
||||||
|
host=self.host,
|
||||||
|
api_token=self.api_token,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
)
|
||||||
|
elif self.cluster_id and self.cluster_driver_port:
|
||||||
|
self._client = _DatabricksClusterDriverProxyClient(
|
||||||
|
host=self.host,
|
||||||
|
api_token=self.api_token,
|
||||||
|
cluster_id=self.cluster_id,
|
||||||
|
cluster_driver_port=self.cluster_driver_port,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Must specify either endpoint_name or cluster_id/cluster_driver_port."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "databricks"
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Queries the LLM endpoint with the given prompt and stop sequence."""
|
||||||
|
|
||||||
|
# TODO: support callbacks
|
||||||
|
|
||||||
|
request = {"prompt": prompt, "stop": stop}
|
||||||
|
if self.model_kwargs:
|
||||||
|
request.update(self.model_kwargs)
|
||||||
|
|
||||||
|
if self.transform_input_fn:
|
||||||
|
request = self.transform_input_fn(**request)
|
||||||
|
|
||||||
|
response = self._client.post(request)
|
||||||
|
|
||||||
|
if self.transform_output_fn:
|
||||||
|
response = self.transform_output_fn(response)
|
||||||
|
|
||||||
|
return response
|
Loading…
Reference in New Issue
Block a user