Add Spark SQL support (#4602) (#4956)

# Add Spark SQL support 
* Add Spark SQL support. It connects to Spark by building a local or
remote SparkSession.
* Include an example notebook.

I tried some complicated queries (window functions, table joins), and the
tool works well.
Compared to the [Spark DataFrame
agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/spark.html),
this tool is able to generate queries across multiple tables.
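
A minimal sketch of how the pieces fit together (an OpenAI API key is assumed; the Spark Connect URI and the multi-table question are illustrative):

```python
from langchain.agents import create_spark_sql_agent
from langchain.agents.agent_toolkits import SparkSQLToolkit
from langchain.chat_models import ChatOpenAI
from langchain.utilities.spark_sql import SparkSQL

# Wrap a local SparkSession, or a remote one via Spark Connect:
# spark_sql = SparkSQL.from_uri("sc://localhost:15002", schema="langchain_example")
spark_sql = SparkSQL(schema="langchain_example")

llm = ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

# Illustrative multi-table question (table names are hypothetical):
agent_executor.run("Join orders with customers and report the top 3 spenders.")
```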

---------

Co-authored-by: Gengliang Wang <gengliang@apache.org>
Co-authored-by: Mike W <62768671+skcoirz@users.noreply.github.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: UmerHA <40663591+UmerHA@users.noreply.github.com>
Co-authored-by: 张城铭 <z@hyperf.io>
Co-authored-by: assert <zhangchengming@kkguan.com>
Co-authored-by: blob42 <spike@w530>
Co-authored-by: Yuekai Zhang <zhangyuekai@foxmail.com>
Co-authored-by: Richard He <he.yucheng@outlook.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Leonid Ganeline <leo.gan.57@gmail.com>
Co-authored-by: Alexey Nominas <60900649+Chae4ek@users.noreply.github.com>
Co-authored-by: elBarkey <elbarkey@gmail.com>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com>
Co-authored-by: so2liu <yangliu35@outlook.com>
Co-authored-by: Viswanadh Rayavarapu <44315599+vishwa-rn@users.noreply.github.com>
Co-authored-by: Chakib Ben Ziane <contact@blob42.xyz>
Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Co-authored-by: Daniel Chalef <daniel.chalef@private.org>
Co-authored-by: Jari Bakken <jari.bakken@gmail.com>
Co-authored-by: escafati <scafatieugenio@gmail.com>

@ -0,0 +1,348 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Spark SQL Agent\n",
"\n",
"This notebook shows how to use agents to interact with a Spark SQL. Similar to [SQL Database Agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/sql_database.html), it is designed to address general inquiries about Spark SQL and facilitate error recovery.\n",
"\n",
"**NOTE: Note that, as this agent is in active development, all answers might not be correct. Additionally, it is not guaranteed that the agent won't perform DML statements on your Spark cluster given certain questions. Be careful running it on sensitive data!**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialization"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import create_spark_sql_agent\n",
"from langchain.agents.agent_toolkits import SparkSQLToolkit\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.utilities.spark_sql import SparkSQL"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"23/05/18 16:03:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n",
"| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n",
"| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n",
"| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S|\n",
"| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S|\n",
"| 6| 0| 3| Moran, Mr. James| male|null| 0| 0| 330877| 8.4583| null| Q|\n",
"| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S|\n",
"| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S|\n",
"| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S|\n",
"| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C|\n",
"| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S|\n",
"| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S|\n",
"| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S|\n",
"| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S|\n",
"| 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S|\n",
"| 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S|\n",
"| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q|\n",
"| 18| 1| 2|Williams, Mr. Cha...| male|null| 0| 0| 244373| 13.0| null| S|\n",
"| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S|\n",
"| 20| 1| 3|Masselmani, Mrs. ...|female|null| 0| 0| 2649| 7.225| null| C|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.getOrCreate()\n",
"schema = \"langchain_example\"\n",
"spark.sql(f\"CREATE DATABASE IF NOT EXISTS {schema}\")\n",
"spark.sql(f\"USE {schema}\")\n",
"csv_file_path = \"titanic.csv\"\n",
"table = \"titanic\"\n",
"spark.read.csv(csv_file_path, header=True, inferSchema=True).write.saveAsTable(table)\n",
"spark.table(table).show()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Note, you can also connect to Spark via Spark connect. For example:\n",
"# db = SparkSQL.from_uri(\"sc://localhost:15002\", schema=schema)\n",
"spark_sql = SparkSQL(schema=schema)\n",
"llm = ChatOpenAI(temperature=0)\n",
"toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)\n",
"agent_executor = create_spark_sql_agent(\n",
" llm=llm,\n",
" toolkit=toolkit,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example: describing a table"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n",
"Action Input: \u001B[0m\n",
"Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI found the titanic table. Now I need to get the schema and sample rows for the titanic table.\n",
"Action: schema_sql_db\n",
"Action Input: titanic\u001B[0m\n",
"Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n",
" PassengerId INT,\n",
" Survived INT,\n",
" Pclass INT,\n",
" Name STRING,\n",
" Sex STRING,\n",
" Age DOUBLE,\n",
" SibSp INT,\n",
" Parch INT,\n",
" Ticket STRING,\n",
" Fare DOUBLE,\n",
" Cabin STRING,\n",
" Embarked STRING)\n",
";\n",
"\n",
"/*\n",
"3 rows from titanic table:\n",
"PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
"1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
"2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
"3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
"*/\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI now know the schema and sample rows for the titanic table.\n",
"Final Answer: The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \n",
"\n",
"1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\n",
"2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\n",
"3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": "'The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \\n\\n1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\\n2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\\n3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S'"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Describe the titanic table\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example: running queries"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n",
"Action Input: \u001B[0m\n",
"Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI should check the schema of the titanic table to see if there is an age column.\n",
"Action: schema_sql_db\n",
"Action Input: titanic\u001B[0m\n",
"Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n",
" PassengerId INT,\n",
" Survived INT,\n",
" Pclass INT,\n",
" Name STRING,\n",
" Sex STRING,\n",
" Age DOUBLE,\n",
" SibSp INT,\n",
" Parch INT,\n",
" Ticket STRING,\n",
" Fare DOUBLE,\n",
" Cabin STRING,\n",
" Embarked STRING)\n",
";\n",
"\n",
"/*\n",
"3 rows from titanic table:\n",
"PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
"1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
"2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
"3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
"*/\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mThere is an Age column in the titanic table. I should write a query to calculate the average age and then find the square root of the result.\n",
"Action: query_checker_sql_db\n",
"Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n",
"Observation: \u001B[31;1m\u001B[1;3mThe original query seems to be correct. Here it is again:\n",
"\n",
"SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mThe query is correct, so I can execute it to find the square root of the average age.\n",
"Action: query_sql_db\n",
"Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n",
"Observation: \u001B[36;1m\u001B[1;3m[('5.449689683556195',)]\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI now know the final answer\n",
"Final Answer: The square root of the average age is approximately 5.45.\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": "'The square root of the average age is approximately 5.45.'"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"whats the square root of the average age?\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n",
"Action Input: \u001B[0m\n",
"Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI should check the schema of the titanic table to see what columns are available.\n",
"Action: schema_sql_db\n",
"Action Input: titanic\u001B[0m\n",
"Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n",
" PassengerId INT,\n",
" Survived INT,\n",
" Pclass INT,\n",
" Name STRING,\n",
" Sex STRING,\n",
" Age DOUBLE,\n",
" SibSp INT,\n",
" Parch INT,\n",
" Ticket STRING,\n",
" Fare DOUBLE,\n",
" Cabin STRING,\n",
" Embarked STRING)\n",
";\n",
"\n",
"/*\n",
"3 rows from titanic table:\n",
"PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
"1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
"2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
"3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
"*/\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI can use the titanic table to find the oldest survived passenger. I will query the Name and Age columns, filtering by Survived and ordering by Age in descending order.\n",
"Action: query_checker_sql_db\n",
"Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n",
"Observation: \u001B[31;1m\u001B[1;3mSELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mThe query is correct. Now I will execute it to find the oldest survived passenger.\n",
"Action: query_sql_db\n",
"Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n",
"Observation: \u001B[36;1m\u001B[1;3m[('Barkworth, Mr. Algernon Henry Wilson', '80.0')]\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3mI now know the final answer.\n",
"Final Answer: The oldest survived passenger is Barkworth, Mr. Algernon Henry Wilson, who was 80 years old.\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": "'The oldest survived passenger is Barkworth, Mr. Algernon Henry Wilson, who was 80 years old.'"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"What's the name of the oldest survived passenger?\")"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -15,6 +15,7 @@ from langchain.agents.agent_toolkits import (
create_pbi_agent,
create_pbi_chat_agent,
create_spark_dataframe_agent,
create_spark_sql_agent,
create_sql_agent,
create_vectorstore_agent,
create_vectorstore_router_agent,
@ -59,6 +60,7 @@ __all__ = [
"create_pbi_agent",
"create_pbi_chat_agent",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_sql_agent",
"create_vectorstore_agent",
"create_vectorstore_router_agent",

@ -18,6 +18,8 @@ from langchain.agents.agent_toolkits.powerbi.chat_base import create_pbi_chat_ag
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.agent_toolkits.python.base import create_python_agent
from langchain.agents.agent_toolkits.spark.base import create_spark_dataframe_agent
from langchain.agents.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.agent_toolkits.sql.base import create_sql_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.vectorstore.base import (
@ -41,6 +43,7 @@ __all__ = [
"create_vectorstore_agent",
"JsonToolkit",
"SQLDatabaseToolkit",
"SparkSQLToolkit",
"NLAToolkit",
"PowerBIToolkit",
"OpenAPIToolkit",
@ -50,6 +53,7 @@ __all__ = [
"VectorStoreRouterToolkit",
"create_pandas_dataframe_agent",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_csv_agent",
"ZapierToolkit",
"GmailToolkit",

@ -0,0 +1,56 @@
"""Spark SQL agent."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
def create_spark_sql_agent(
llm: BaseLanguageModel,
toolkit: SparkSQLToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a sql agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(top_k=top_k)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

@ -0,0 +1,21 @@
# flake8: noqa
SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""

@ -0,0 +1,36 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.spark_sql.tool import (
InfoSparkSQLTool,
ListSparkSQLTool,
QueryCheckerTool,
QuerySparkSQLTool,
)
from langchain.utilities.spark_sql import SparkSQL
class SparkSQLToolkit(BaseToolkit):
"""Toolkit for interacting with Spark SQL."""
db: SparkSQL = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QuerySparkSQLTool(db=self.db),
InfoSparkSQLTool(db=self.db),
ListSparkSQLTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]

@ -0,0 +1 @@
"""Tools for interacting with Spark SQL."""

@ -0,0 +1,14 @@
# flake8: noqa
QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""

@ -0,0 +1,152 @@
# flake8: noqa
"""Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.utilities.spark_sql import SparkSQL
from langchain.tools.base import BaseTool
from langchain.tools.spark_sql.prompt import QUERY_CHECKER
class BaseSparkSQLTool(BaseModel):
"""Base tool for interacting with Spark SQL."""
db: SparkSQL = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for querying a Spark SQL."""
name = "query_sql_db"
description = """
Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
return self.db.run_no_throw(query)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("QuerySqlDbTool does not support async")
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting metadata about a Spark SQL."""
name = "schema_sql_db"
description = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first!
Example Input: "table1, table2, table3"
"""
def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(table_names.split(", "))
async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SchemaSqlDbTool does not support async")
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting tables names."""
name = "list_tables_sql_db"
description = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_usable_table_names())
async def _arun(
self,
tool_input: str = "",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")
class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "query_checker_sql_db"
description = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db!
"""
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query"]
),
)
if values["llm_chain"].prompt.input_variables != ["query"]:
raise ValueError(
"LLM chain for QueryCheckerTool need to use ['query'] as input_variables "
"for the embedded prompt"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.predict(query=query)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(query=query)
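
A direct-use sketch of these tools outside the agent (assumes the `langchain_example` schema and `titanic` table created in the notebook above):

from langchain.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QuerySparkSQLTool,
)
from langchain.utilities.spark_sql import SparkSQL

db = SparkSQL(schema="langchain_example")
print(ListSparkSQLTool(db=db).run(""))                                # e.g. "titanic"
print(InfoSparkSQLTool(db=db).run("titanic"))                         # CREATE TABLE statement plus sample rows
print(QuerySparkSQLTool(db=db).run("SELECT COUNT(*) FROM titanic"))   # query result as a string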

@ -16,6 +16,7 @@ from langchain.utilities.powerbi import PowerBIDataset
from langchain.utilities.python import PythonREPL
from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.spark_sql import SparkSQL
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
@ -38,5 +39,6 @@ __all__ = [
"PythonREPL",
"LambdaWrapper",
"PowerBIDataset",
"SparkSQL",
"MetaphorSearchAPIWrapper",
]

@ -0,0 +1,174 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
if TYPE_CHECKING:
from pyspark.sql import DataFrame, Row, SparkSession
class SparkSQL:
def __init__(
self,
spark_session: Optional[SparkSession] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
):
try:
from pyspark.sql import SparkSession
except ImportError:
raise ValueError(
"pyspark is not installed. Please install it with `pip install pyspark`"
)
self._spark = (
spark_session if spark_session else SparkSession.builder.getOrCreate()
)
if catalog is not None:
self._spark.catalog.setCurrentCatalog(catalog)
if schema is not None:
self._spark.catalog.setCurrentDatabase(schema)
self._all_tables = set(self._get_all_table_names())
self._include_tables = set(include_tables) if include_tables else set()
if self._include_tables:
missing_tables = self._include_tables - self._all_tables
if missing_tables:
raise ValueError(
f"include_tables {missing_tables} not found in database"
)
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
if self._ignore_tables:
missing_tables = self._ignore_tables - self._all_tables
if missing_tables:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
usable_tables = self.get_usable_table_names()
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
if not isinstance(sample_rows_in_table_info, int):
raise TypeError("sample_rows_in_table_info must be an integer")
self._sample_rows_in_table_info = sample_rows_in_table_info
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> SparkSQL:
"""Creating a remote Spark Session via Spark connect.
For example: SparkSQL.from_uri("sc://localhost:15002")
"""
try:
from pyspark.sql import SparkSession
except ImportError:
raise ValueError(
"pyspark is not installed. Please install it with `pip install pyspark`"
)
spark = SparkSession.builder.remote(database_uri).getOrCreate()
return cls(spark, **kwargs)
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
# Sorting the result can help the LLM understand it.
return sorted(self._all_tables - self._ignore_tables)
def _get_all_table_names(self) -> Iterable[str]:
rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
return list(map(lambda row: row.tableName, rows))
def _get_create_table_stmt(self, table: str) -> str:
statement = (
self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
)
# Ignore the data source provider and options to reduce the number of tokens.
using_clause_index = statement.find("USING")
return statement[:using_clause_index] + ";"
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
tables = []
for table_name in all_table_names:
table_info = self._get_create_table_stmt(table_name)
if self._sample_rows_in_table_info:
table_info += "\n\n/*"
table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
table_info += "*/"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def _get_sample_spark_rows(self, table: str) -> str:
query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
df = self._spark.sql(query)
columns_str = "\t".join(list(map(lambda f: f.name, df.schema.fields)))
try:
sample_rows = self._get_dataframe_results(df)
# save the sample rows in string format
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
except Exception:
sample_rows_str = ""
return (
f"{self._sample_rows_in_table_info} rows from {table} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}"
)
def _convert_row_as_tuple(self, row: Row) -> tuple:
return tuple(map(str, row.asDict().values()))
def _get_dataframe_results(self, df: DataFrame) -> list:
return list(map(self._convert_row_as_tuple, df.collect()))
def run(self, command: str, fetch: str = "all") -> str:
df = self._spark.sql(command)
if fetch == "one":
df = df.limit(1)
return str(self._get_dataframe_results(df))
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
try:
return self.get_table_info(table_names)
except ValueError as e:
"""Format the error message"""
return f"Error: {e}"
def run_no_throw(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement throws an error, the error message is returned.
"""
try:
from pyspark.errors import PySparkException
except ImportError:
raise ValueError(
"pyspark is not installed. Please install it with `pip install pyspark`"
)
try:
return self.run(command, fetch)
except PySparkException as e:
"""Format the error message"""
return f"Error: {e}"

@ -24,6 +24,7 @@ _EXPECTED = [
"create_pbi_agent",
"create_pbi_chat_agent",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_sql_agent",
"create_vectorstore_agent",
"create_vectorstore_router_agent",
