mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
534 lines
15 KiB
Plaintext
534 lines
15 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {},
|
||
"inputWidgets": {},
|
||
"nuid": "5147e458-3b83-449e-9c2f-e7e1972e43fc",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"source": [
|
||
"# Databricks\n",
|
||
"\n",
|
||
"The [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n",
|
||
"\n",
|
||
"This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain.\n",
|
||
"It supports two endpoint types:\n",
|
||
"* Serving endpoint, recommended for production and development,\n",
|
||
"* Cluster driver proxy app, recommended for interactive development."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "bf07455f-aac9-4873-a8e7-7952af0f8c82",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain.llms import Databricks"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {},
|
||
"inputWidgets": {},
|
||
"nuid": "94f6540e-40cd-4d9b-95d3-33d36f061dcc",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"source": [
|
||
"## Wrapping a serving endpoint\n",
|
||
"\n",
|
||
"Prerequisites:\n",
|
||
"* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html).\n",
|
||
"* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
|
||
"\n",
|
||
"The expected MLflow model signature is:\n",
|
||
" * inputs: `[{\"name\": \"prompt\", \"type\": \"string\"}, {\"name\": \"stop\", \"type\": \"list[string]\"}]`\n",
|
||
" * outputs: `[{\"type\": \"string\"}]`\n",
|
||
"\n",
|
||
"If the model signature is incompatible or you want to insert extra configs, you can set `transform_input_fn` and `transform_output_fn` accordingly."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "7496dc7a-8a1a-4ce6-9648-4f69ed25275b",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I am happy to hear that you are in good health and as always, you are appreciated.'"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# If running a Databricks notebook attached to an interactive cluster in \"single user\"\n",
|
||
"# or \"no isolation shared\" mode, you only need to specify the endpoint name to create\n",
|
||
"# a `Databricks` instance to query a serving endpoint in the same workspace.\n",
|
||
"llm = Databricks(endpoint_name=\"dolly\")\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "0c86d952-4236-4a5e-bdac-cf4e3ccf3a16",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Good'"
|
||
]
|
||
},
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"llm(\"How are you?\", stop=[\".\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "5f2507a2-addd-431d-9da5-dc2ae33783f6",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I am fine. Thank you!'"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Otherwise, you can manually specify the Databricks workspace hostname and personal access token\n",
|
||
"# or set `DATABRICKS_HOST` and `DATABRICKS_TOKEN` environment variables, respectively.\n",
|
||
"# See https://docs.databricks.com/dev-tools/auth.html#databricks-personal-access-tokens\n",
|
||
"# We strongly recommend not exposing the API token explicitly inside a notebook.\n",
|
||
"# You can use Databricks secret manager to store your API token securely.\n",
|
||
"# See https://docs.databricks.com/dev-tools/databricks-utils.html#secrets-utility-dbutilssecrets\n",
|
||
"\n",
|
||
"import os\n",
|
||
"\n",
|
||
"os.environ[\"DATABRICKS_TOKEN\"] = dbutils.secrets.get(\"myworkspace\", \"api_token\")\n",
|
||
"\n",
|
||
"llm = Databricks(host=\"myworkspace.cloud.databricks.com\", endpoint_name=\"dolly\")\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "9b54f8ce-ffe5-4c47-a3f0-b4ebde524a6a",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I am fine.'"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# If the serving endpoint accepts extra parameters like `temperature`,\n",
|
||
"# you can set them in `model_kwargs`.\n",
|
||
"llm = Databricks(endpoint_name=\"dolly\", model_kwargs={\"temperature\": 0.1})\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "50f172f5-ea1f-4ceb-8cf1-20289848de7b",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I’m Excellent. You?'"
|
||
]
|
||
},
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Use `transform_input_fn` and `transform_output_fn` if the serving endpoint\n",
|
||
"# expects a different input schema and does not return a JSON string,\n",
|
||
"# respectively, or you want to apply a prompt template on top.\n",
|
||
"\n",
|
||
"\n",
|
||
"def transform_input(**request):\n",
|
||
" full_prompt = f\"\"\"{request[\"prompt\"]}\n",
|
||
" Be Concise.\n",
|
||
" \"\"\"\n",
|
||
" request[\"prompt\"] = full_prompt\n",
|
||
" return request\n",
|
||
"\n",
|
||
"\n",
|
||
"llm = Databricks(endpoint_name=\"dolly\", transform_input_fn=transform_input)\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {},
|
||
"inputWidgets": {},
|
||
"nuid": "8ea49319-a041-494d-afcd-87bcf00d5efb",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"source": [
|
||
"## Wrapping a cluster driver proxy app\n",
|
||
"\n",
|
||
"Prerequisites:\n",
|
||
"* An LLM loaded on a Databricks interactive cluster in \"single user\" or \"no isolation shared\" mode.\n",
|
||
"* A local HTTP server running on the driver node to serve the model at `\"/\"` using HTTP POST with JSON input/output.\n",
|
||
"* It uses a port number in the range `[3000, 8000]` and listens on the driver IP address or simply `0.0.0.0` instead of localhost only.\n",
|
||
"* You have \"Can Attach To\" permission to the cluster.\n",
|
||
"\n",
|
||
"The expected server schema (using JSON schema) is:\n",
|
||
"* inputs:\n",
|
||
" ```json\n",
|
||
" {\"type\": \"object\",\n",
|
||
" \"properties\": {\n",
|
||
" \"prompt\": {\"type\": \"string\"},\n",
|
||
" \"stop\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}},\n",
|
||
" \"required\": [\"prompt\"]}\n",
|
||
" ```\n",
|
||
"* outputs: `{\"type\": \"string\"}`\n",
|
||
"\n",
|
||
"If the server schema is incompatible or you want to insert extra configs, you can use `transform_input_fn` and `transform_output_fn` accordingly.\n",
|
||
"\n",
|
||
"The following is a minimal example for running a driver proxy app to serve an LLM:\n",
|
||
"\n",
|
||
"```python\n",
|
||
"from flask import Flask, request, jsonify\n",
|
||
"import torch\n",
|
||
"from transformers import pipeline, AutoTokenizer, StoppingCriteria\n",
|
||
"\n",
|
||
"model = \"databricks/dolly-v2-3b\"\n",
|
||
"tokenizer = AutoTokenizer.from_pretrained(model, padding_side=\"left\")\n",
|
||
"dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map=\"auto\")\n",
|
||
"device = dolly.device\n",
|
||
"\n",
|
||
"class CheckStop(StoppingCriteria):\n",
|
||
" def __init__(self, stop=None):\n",
|
||
" super().__init__()\n",
|
||
" self.stop = stop or []\n",
|
||
" self.matched = \"\"\n",
|
||
" self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]\n",
|
||
" def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
|
||
" for i, s in enumerate(self.stop_ids):\n",
|
||
" if torch.all((s == input_ids[0][-s.shape[1]:])).item():\n",
|
||
" self.matched = self.stop[i]\n",
|
||
" return True\n",
|
||
" return False\n",
|
||
"\n",
|
||
"def llm(prompt, stop=None, **kwargs):\n",
|
||
" check_stop = CheckStop(stop)\n",
|
||
" result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)\n",
|
||
" return result[0][\"generated_text\"].rstrip(check_stop.matched)\n",
|
||
"\n",
|
||
"app = Flask(\"dolly\")\n",
|
||
"\n",
|
||
"@app.route('/', methods=['POST'])\n",
|
||
"def serve_llm():\n",
|
||
" resp = llm(**request.json)\n",
|
||
" return jsonify(resp)\n",
|
||
"\n",
|
||
"app.run(host=\"0.0.0.0\", port=\"7777\")\n",
|
||
"```\n",
|
||
"\n",
|
||
"Once the server is running, you can create a `Databricks` instance to wrap it as an LLM."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "e3330a01-e738-4170-a176-9954aff56442",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Hello, thank you for asking. It is wonderful to hear that you are well.'"
|
||
]
|
||
},
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# If running a Databricks notebook attached to the same cluster that runs the app,\n",
|
||
"# you only need to specify the driver port to create a `Databricks` instance.\n",
|
||
"llm = Databricks(cluster_driver_port=\"7777\")\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "39c121cf-0e44-4e31-91db-37fcac459677",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I am well. You?'"
|
||
]
|
||
},
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Otherwise, you can manually specify the cluster ID to use,\n",
|
||
"# as well as Databricks workspace hostname and personal access token.\n",
|
||
"\n",
|
||
"llm = Databricks(cluster_id=\"0000-000000-xxxxxxxx\", cluster_driver_port=\"7777\")\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "3d3de599-82fd-45e4-8d8b-bacfc49dc9ce",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I am very well. It is a pleasure to meet you.'"
|
||
]
|
||
},
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# If the app accepts extra parameters like `temperature`,\n",
|
||
"# you can set them in `model_kwargs`.\n",
|
||
"llm = Databricks(cluster_driver_port=\"7777\", model_kwargs={\"temperature\": 0.1})\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+cell": {
|
||
"cellMetadata": {
|
||
"byteLimit": 2048000,
|
||
"rowLimit": 10000
|
||
},
|
||
"inputWidgets": {},
|
||
"nuid": "853fae8e-8df4-41e6-9d45-7769f883fe80",
|
||
"showTitle": false,
|
||
"title": ""
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'I AM DOING GREAT THANK YOU.'"
|
||
]
|
||
},
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Use `transform_input_fn` and `transform_output_fn` if the app\n",
|
||
"# expects a different input schema and does not return a JSON string,\n",
|
||
"# respectively, or you want to apply a prompt template on top.\n",
|
||
"\n",
|
||
"\n",
|
||
"def transform_input(**request):\n",
|
||
" full_prompt = f\"\"\"{request[\"prompt\"]}\n",
|
||
" Be Concise.\n",
|
||
" \"\"\"\n",
|
||
" request[\"prompt\"] = full_prompt\n",
|
||
" return request\n",
|
||
"\n",
|
||
"\n",
|
||
"def transform_output(response):\n",
|
||
" return response.upper()\n",
|
||
"\n",
|
||
"\n",
|
||
"llm = Databricks(\n",
|
||
" cluster_driver_port=\"7777\",\n",
|
||
" transform_input_fn=transform_input,\n",
|
||
" transform_output_fn=transform_output,\n",
|
||
")\n",
|
||
"\n",
|
||
"llm(\"How are you?\")"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"application/vnd.databricks.v1+notebook": {
|
||
"dashboards": [],
|
||
"language": "python",
|
||
"notebookMetadata": {
|
||
"pythonIndentUnit": 2
|
||
},
|
||
"notebookName": "databricks",
|
||
"widgets": {}
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "llm",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.10"
|
||
},
|
||
"orig_nbformat": 4
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|