mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
274 lines
9.4 KiB
Plaintext
274 lines
9.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "707d13a7",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Databricks\n",
|
|
"\n",
|
|
"This notebook covers how to connect to the [Databricks runtimes](https://docs.databricks.com/runtime/index.html) and [Databricks SQL](https://www.databricks.com/product/databricks-sql) using the SQLDatabase wrapper of LangChain.\n",
|
|
"It is broken into 3 parts: installation and setup, connecting to Databricks, and examples."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0076d072",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Installation and Setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "739b489b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install databricks-sql-connector"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "73113163",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Connecting to Databricks\n",
|
|
"\n",
|
|
"You can connect to [Databricks runtimes](https://docs.databricks.com/runtime/index.html) and [Databricks SQL](https://www.databricks.com/product/databricks-sql) using the `SQLDatabase.from_databricks()` method.\n",
|
|
"\n",
|
|
"### Syntax\n",
|
|
"```python\n",
|
|
"SQLDatabase.from_databricks(\n",
|
|
" catalog: str,\n",
|
|
" schema: str,\n",
|
|
" host: Optional[str] = None,\n",
|
|
" api_token: Optional[str] = None,\n",
|
|
" warehouse_id: Optional[str] = None,\n",
|
|
" cluster_id: Optional[str] = None,\n",
|
|
" engine_args: Optional[dict] = None,\n",
|
|
" **kwargs: Any)\n",
|
|
"```\n",
|
|
"### Required Parameters\n",
|
|
"* `catalog`: The catalog name in the Databricks database.\n",
|
|
"* `schema`: The schema name in the catalog.\n",
|
|
"\n",
|
|
"### Optional Parameters\n",
|
|
"There following parameters are optional. When executing the method in a Databricks notebook, you don't need to provide them in most of the cases.\n",
|
|
"* `host`: The Databricks workspace hostname, excluding 'https://' part. Defaults to 'DATABRICKS_HOST' environment variable or current workspace if in a Databricks notebook.\n",
|
|
"* `api_token`: The Databricks personal access token for accessing the Databricks SQL warehouse or the cluster. Defaults to 'DATABRICKS_TOKEN' environment variable or a temporary one is generated if in a Databricks notebook.\n",
|
|
"* `warehouse_id`: The warehouse ID in the Databricks SQL.\n",
|
|
"* `cluster_id`: The cluster ID in the Databricks Runtime. If running in a Databricks notebook and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the cluster the notebook is attached to.\n",
|
|
"* `engine_args`: The arguments to be used when connecting Databricks.\n",
|
|
"* `**kwargs`: Additional keyword arguments for the `SQLDatabase.from_uri` method."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b11c7e48",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Examples"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "8102bca0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Connecting to Databricks with SQLDatabase wrapper\n",
|
|
"from langchain import SQLDatabase\n",
|
|
"\n",
|
|
"db = SQLDatabase.from_databricks(catalog=\"samples\", schema=\"nyctaxi\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "9dd36f58",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Creating a OpenAI Chat LLM wrapper\n",
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"\n",
|
|
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5b5c5f1a",
|
|
"metadata": {},
|
|
"source": [
|
|
"### SQL Chain example\n",
|
|
"\n",
|
|
"This example demonstrates the use of the [SQL Chain](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html) for answering a question over a Databricks database."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "36f2270b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain import SQLDatabaseChain\n",
|
|
"\n",
|
|
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "4e2b5f25",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
|
"What is the average duration of taxi rides that start between midnight and 6am?\n",
|
|
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT AVG(UNIX_TIMESTAMP(tpep_dropoff_datetime) - UNIX_TIMESTAMP(tpep_pickup_datetime)) as avg_duration\n",
|
|
"FROM trips\n",
|
|
"WHERE HOUR(tpep_pickup_datetime) >= 0 AND HOUR(tpep_pickup_datetime) < 6\u001b[0m\n",
|
|
"SQLResult: \u001b[33;1m\u001b[1;3m[(987.8122786304605,)]\u001b[0m\n",
|
|
"Answer:\u001b[32;1m\u001b[1;3mThe average duration of taxi rides that start between midnight and 6am is 987.81 seconds.\u001b[0m\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'The average duration of taxi rides that start between midnight and 6am is 987.81 seconds.'"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"db_chain.run(\n",
|
|
" \"What is the average duration of taxi rides that start between midnight and 6am?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e496d5e5",
|
|
"metadata": {},
|
|
"source": [
|
|
"### SQL Database Agent example\n",
|
|
"\n",
|
|
"This example demonstrates the use of the [SQL Database Agent](/docs/modules/agents/toolkits/sql_database.html) for answering questions over a Databricks database."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "9918e86a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.agents import create_sql_agent\n",
|
|
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
|
|
"\n",
|
|
"toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
|
|
"agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "c484a76e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
|
|
"Action Input: \u001b[0m\n",
|
|
"Observation: \u001b[38;5;200m\u001b[1;3mtrips\u001b[0m\n",
|
|
"Thought:\u001b[32;1m\u001b[1;3mI should check the schema of the trips table to see if it has the necessary columns for trip distance and duration.\n",
|
|
"Action: schema_sql_db\n",
|
|
"Action Input: trips\u001b[0m\n",
|
|
"Observation: \u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE trips (\n",
|
|
"\ttpep_pickup_datetime TIMESTAMP, \n",
|
|
"\ttpep_dropoff_datetime TIMESTAMP, \n",
|
|
"\ttrip_distance FLOAT, \n",
|
|
"\tfare_amount FLOAT, \n",
|
|
"\tpickup_zip INT, \n",
|
|
"\tdropoff_zip INT\n",
|
|
") USING DELTA\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from trips table:\n",
|
|
"tpep_pickup_datetime\ttpep_dropoff_datetime\ttrip_distance\tfare_amount\tpickup_zip\tdropoff_zip\n",
|
|
"2016-02-14 16:52:13+00:00\t2016-02-14 17:16:04+00:00\t4.94\t19.0\t10282\t10171\n",
|
|
"2016-02-04 18:44:19+00:00\t2016-02-04 18:46:00+00:00\t0.28\t3.5\t10110\t10110\n",
|
|
"2016-02-17 17:13:57+00:00\t2016-02-17 17:17:55+00:00\t0.7\t5.0\t10103\t10023\n",
|
|
"*/\u001b[0m\n",
|
|
"Thought:\u001b[32;1m\u001b[1;3mThe trips table has the necessary columns for trip distance and duration. I will write a query to find the longest trip distance and its duration.\n",
|
|
"Action: query_checker_sql_db\n",
|
|
"Action Input: SELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as duration FROM trips ORDER BY trip_distance DESC LIMIT 1\u001b[0m\n",
|
|
"Observation: \u001b[31;1m\u001b[1;3mSELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as duration FROM trips ORDER BY trip_distance DESC LIMIT 1\u001b[0m\n",
|
|
"Thought:\u001b[32;1m\u001b[1;3mThe query is correct. I will now execute it to find the longest trip distance and its duration.\n",
|
|
"Action: query_sql_db\n",
|
|
"Action Input: SELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as duration FROM trips ORDER BY trip_distance DESC LIMIT 1\u001b[0m\n",
|
|
"Observation: \u001b[36;1m\u001b[1;3m[(30.6, '0 00:43:31.000000000')]\u001b[0m\n",
|
|
"Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
|
"Final Answer: The longest trip distance is 30.6 miles and it took 43 minutes and 31 seconds.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'The longest trip distance is 30.6 miles and it took 43 minutes and 31 seconds.'"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent.run(\"What is the longest trip distance and how long did it take?\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|