diff --git a/docs/modules/agents/toolkits/examples/spark_sql.ipynb b/docs/modules/agents/toolkits/examples/spark_sql.ipynb new file mode 100644 index 00000000..c4405c05 --- /dev/null +++ b/docs/modules/agents/toolkits/examples/spark_sql.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Spark SQL Agent\n", + "\n", + "This notebook shows how to use agents to interact with a Spark SQL. Similar to [SQL Database Agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/sql_database.html), it is designed to address general inquiries about Spark SQL and facilitate error recovery.\n", + "\n", + "**NOTE: Note that, as this agent is in active development, all answers might not be correct. Additionally, it is not guaranteed that the agent won't perform DML statements on your Spark cluster given certain questions. Be careful running it on sensitive data!**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import create_spark_sql_agent\n", + "from langchain.agents.agent_toolkits import SparkSQLToolkit\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.utilities.spark_sql import SparkSQL" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting default log level to \"WARN\".\n", + "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", + "23/05/18 16:03:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n", + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n", + "| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n", + "| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n", + "| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S|\n", + "| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S|\n", + "| 6| 0| 3| Moran, Mr. James| male|null| 0| 0| 330877| 8.4583| null| Q|\n", + "| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S|\n", + "| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S|\n", + "| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S|\n", + "| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C|\n", + "| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S|\n", + "| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S|\n", + "| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S|\n", + "| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S|\n", + "| 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S|\n", + "| 16| 1| 2|Hewlett, Mrs. 
(Ma...|female|55.0| 0| 0| 248706| 16.0| null| S|\n", + "| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q|\n", + "| 18| 1| 2|Williams, Mr. Cha...| male|null| 0| 0| 244373| 13.0| null| S|\n", + "| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S|\n", + "| 20| 1| 3|Masselmani, Mrs. ...|female|null| 0| 0| 2649| 7.225| null| C|\n", + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "only showing top 20 rows\n", + "\n" + ] + } + ], + "source": [ + "from pyspark.sql import SparkSession\n", + "\n", + "spark = SparkSession.builder.getOrCreate()\n", + "schema = \"langchain_example\"\n", + "spark.sql(f\"CREATE DATABASE IF NOT EXISTS {schema}\")\n", + "spark.sql(f\"USE {schema}\")\n", + "csv_file_path = \"titanic.csv\"\n", + "table = \"titanic\"\n", + "spark.read.csv(csv_file_path, header=True, inferSchema=True).write.saveAsTable(table)\n", + "spark.table(table).show()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Note, you can also connect to Spark via Spark connect. For example:\n", + "# db = SparkSQL.from_uri(\"sc://localhost:15002\", schema=schema)\n", + "spark_sql = SparkSQL(schema=schema)\n", + "llm = ChatOpenAI(temperature=0)\n", + "toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)\n", + "agent_executor = create_spark_sql_agent(\n", + " llm=llm,\n", + " toolkit=toolkit,\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example: describing a table" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n", + "Action Input: \u001B[0m\n", + "Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI found the titanic table. Now I need to get the schema and sample rows for the titanic table.\n", + "Action: schema_sql_db\n", + "Action Input: titanic\u001B[0m\n", + "Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n", + " PassengerId INT,\n", + " Survived INT,\n", + " Pclass INT,\n", + " Name STRING,\n", + " Sex STRING,\n", + " Age DOUBLE,\n", + " SibSp INT,\n", + " Parch INT,\n", + " Ticket STRING,\n", + " Fare DOUBLE,\n", + " Cabin STRING,\n", + " Embarked STRING)\n", + ";\n", + "\n", + "/*\n", + "3 rows from titanic table:\n", + "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n", + "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n", + "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n", + "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n", + "*/\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI now know the schema and sample rows for the titanic table.\n", + "Final Answer: The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \n", + "\n", + "1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. 
Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\n", + "2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\n", + "3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": "'The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \\n\\n1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\\n2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\\n3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S'" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\"Describe the titanic table\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example: running queries" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n", + "Action Input: \u001B[0m\n", + "Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI should check the schema of the titanic table to see if there is an age column.\n", + "Action: schema_sql_db\n", + "Action Input: titanic\u001B[0m\n", + "Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n", + " PassengerId INT,\n", + " Survived INT,\n", + " Pclass INT,\n", + " Name STRING,\n", + " Sex STRING,\n", + " Age DOUBLE,\n", + " SibSp INT,\n", + " Parch INT,\n", + " Ticket STRING,\n", + " Fare DOUBLE,\n", + " Cabin STRING,\n", + " Embarked STRING)\n", + ";\n", + "\n", + "/*\n", + "3 rows from titanic table:\n", + "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n", + "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n", + "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n", + "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n", + "*/\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mThere is an Age column in the titanic table. I should write a query to calculate the average age and then find the square root of the result.\n", + "Action: query_checker_sql_db\n", + "Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n", + "Observation: \u001B[31;1m\u001B[1;3mThe original query seems to be correct. 
Here it is again:\n", + "\n", + "SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mThe query is correct, so I can execute it to find the square root of the average age.\n", + "Action: query_sql_db\n", + "Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3m[('5.449689683556195',)]\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI now know the final answer\n", + "Final Answer: The square root of the average age is approximately 5.45.\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": "'The square root of the average age is approximately 5.45.'" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\"whats the square root of the average age?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mAction: list_tables_sql_db\n", + "Action Input: \u001B[0m\n", + "Observation: \u001B[38;5;200m\u001B[1;3mtitanic\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI should check the schema of the titanic table to see what columns are available.\n", + "Action: schema_sql_db\n", + "Action Input: titanic\u001B[0m\n", + "Observation: \u001B[33;1m\u001B[1;3mCREATE TABLE langchain_example.titanic (\n", + " PassengerId INT,\n", + " Survived INT,\n", + " Pclass INT,\n", + " Name STRING,\n", + " Sex STRING,\n", + " Age DOUBLE,\n", + " SibSp INT,\n", + " Parch INT,\n", + " Ticket STRING,\n", + " Fare DOUBLE,\n", + " Cabin STRING,\n", + " Embarked STRING)\n", + ";\n", + "\n", + "/*\n", + "3 rows from titanic table:\n", + "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n", + "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n", + "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n", + "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n", + "*/\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI can use the titanic table to find the oldest survived passenger. I will query the Name and Age columns, filtering by Survived and ordering by Age in descending order.\n", + "Action: query_checker_sql_db\n", + "Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n", + "Observation: \u001B[31;1m\u001B[1;3mSELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mThe query is correct. Now I will execute it to find the oldest survived passenger.\n", + "Action: query_sql_db\n", + "Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3m[('Barkworth, Mr. Algernon Henry Wilson', '80.0')]\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3mI now know the final answer.\n", + "Final Answer: The oldest survived passenger is Barkworth, Mr. Algernon Henry Wilson, who was 80 years old.\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": "'The oldest survived passenger is Barkworth, Mr. 
Algernon Henry Wilson, who was 80 years old.'" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\"What's the name of the oldest survived passenger?\")" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index 86af75ff..28a9d32d 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -15,6 +15,7 @@ from langchain.agents.agent_toolkits import ( create_pbi_agent, create_pbi_chat_agent, create_spark_dataframe_agent, + create_spark_sql_agent, create_sql_agent, create_vectorstore_agent, create_vectorstore_router_agent, @@ -59,6 +60,7 @@ __all__ = [ "create_pbi_agent", "create_pbi_chat_agent", "create_spark_dataframe_agent", + "create_spark_sql_agent", "create_sql_agent", "create_vectorstore_agent", "create_vectorstore_router_agent", diff --git a/langchain/agents/agent_toolkits/__init__.py b/langchain/agents/agent_toolkits/__init__.py index 4ceb63ec..76defec7 100644 --- a/langchain/agents/agent_toolkits/__init__.py +++ b/langchain/agents/agent_toolkits/__init__.py @@ -18,6 +18,8 @@ from langchain.agents.agent_toolkits.powerbi.chat_base import create_pbi_chat_ag from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit from langchain.agents.agent_toolkits.python.base import create_python_agent from langchain.agents.agent_toolkits.spark.base import create_spark_dataframe_agent +from langchain.agents.agent_toolkits.spark_sql.base import create_spark_sql_agent +from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit from langchain.agents.agent_toolkits.sql.base import create_sql_agent from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit from langchain.agents.agent_toolkits.vectorstore.base import ( @@ -41,6 +43,7 @@ __all__ = [ "create_vectorstore_agent", "JsonToolkit", "SQLDatabaseToolkit", + "SparkSQLToolkit", "NLAToolkit", "PowerBIToolkit", "OpenAPIToolkit", @@ -50,6 +53,7 @@ __all__ = [ "VectorStoreRouterToolkit", "create_pandas_dataframe_agent", "create_spark_dataframe_agent", + "create_spark_sql_agent", "create_csv_agent", "ZapierToolkit", "GmailToolkit", diff --git a/langchain/agents/agent_toolkits/spark_sql/__init__.py b/langchain/agents/agent_toolkits/spark_sql/__init__.py new file mode 100644 index 00000000..4308c079 --- /dev/null +++ b/langchain/agents/agent_toolkits/spark_sql/__init__.py @@ -0,0 +1 @@ +"""Spark SQL agent.""" diff --git a/langchain/agents/agent_toolkits/spark_sql/base.py b/langchain/agents/agent_toolkits/spark_sql/base.py new file mode 100644 index 00000000..3d4eb399 --- /dev/null +++ b/langchain/agents/agent_toolkits/spark_sql/base.py @@ -0,0 +1,56 @@ +"""Spark SQL agent.""" +from typing import Any, Dict, List, Optional + +from langchain.agents.agent import AgentExecutor +from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX +from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.mrkl.prompt 
import FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain + + +def create_spark_sql_agent( + llm: BaseLanguageModel, + toolkit: SparkSQLToolkit, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = SQL_PREFIX, + suffix: str = SQL_SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], +) -> AgentExecutor: + """Construct a sql agent from an LLM and tools.""" + tools = toolkit.get_tools() + prefix = prefix.format(top_k=top_k) + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + format_instructions=format_instructions, + input_variables=input_variables, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/langchain/agents/agent_toolkits/spark_sql/prompt.py b/langchain/agents/agent_toolkits/spark_sql/prompt.py new file mode 100644 index 00000000..b499085d --- /dev/null +++ b/langchain/agents/agent_toolkits/spark_sql/prompt.py @@ -0,0 +1,21 @@ +# flake8: noqa + +SQL_PREFIX = """You are an agent designed to interact with Spark SQL. +Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You have access to tools for interacting with the database. +Only use the below tools. Only use the information returned by the below tools to construct your final answer. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +If the question does not seem related to the database, just return "I don't know" as the answer. +""" + +SQL_SUFFIX = """Begin! + +Question: {input} +Thought: I should look at the tables in the database to see what I can query. 
+{agent_scratchpad}""" diff --git a/langchain/agents/agent_toolkits/spark_sql/toolkit.py b/langchain/agents/agent_toolkits/spark_sql/toolkit.py new file mode 100644 index 00000000..620f12f9 --- /dev/null +++ b/langchain/agents/agent_toolkits/spark_sql/toolkit.py @@ -0,0 +1,36 @@ +"""Toolkit for interacting with Spark SQL.""" +from typing import List + +from pydantic import Field + +from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.base_language import BaseLanguageModel +from langchain.tools import BaseTool +from langchain.tools.spark_sql.tool import ( + InfoSparkSQLTool, + ListSparkSQLTool, + QueryCheckerTool, + QuerySparkSQLTool, +) +from langchain.utilities.spark_sql import SparkSQL + + +class SparkSQLToolkit(BaseToolkit): + """Toolkit for interacting with Spark SQL.""" + + db: SparkSQL = Field(exclude=True) + llm: BaseLanguageModel = Field(exclude=True) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return [ + QuerySparkSQLTool(db=self.db), + InfoSparkSQLTool(db=self.db), + ListSparkSQLTool(db=self.db), + QueryCheckerTool(db=self.db, llm=self.llm), + ] diff --git a/langchain/tools/spark_sql/__init__.py b/langchain/tools/spark_sql/__init__.py new file mode 100644 index 00000000..01039b77 --- /dev/null +++ b/langchain/tools/spark_sql/__init__.py @@ -0,0 +1 @@ +"""Tools for interacting with Spark SQL.""" diff --git a/langchain/tools/spark_sql/prompt.py b/langchain/tools/spark_sql/prompt.py new file mode 100644 index 00000000..98a523b8 --- /dev/null +++ b/langchain/tools/spark_sql/prompt.py @@ -0,0 +1,14 @@ +# flake8: noqa +QUERY_CHECKER = """ +{query} +Double check the Spark SQL query above for common mistakes, including: +- Using NOT IN with NULL values +- Using UNION when UNION ALL should have been used +- Using BETWEEN for exclusive ranges +- Data type mismatch in predicates +- Properly quoting identifiers +- Using the correct number of arguments for functions +- Casting to the correct data type +- Using the proper columns for joins + +If there are any of the above mistakes, rewrite the query. 
If there are no mistakes, just reproduce the original query.""" diff --git a/langchain/tools/spark_sql/tool.py b/langchain/tools/spark_sql/tool.py new file mode 100644 index 00000000..43f862aa --- /dev/null +++ b/langchain/tools/spark_sql/tool.py @@ -0,0 +1,152 @@ +# flake8: noqa +"""Tools for interacting with Spark SQL.""" +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.chains.llm import LLMChain +from langchain.prompts import PromptTemplate +from langchain.utilities.spark_sql import SparkSQL +from langchain.tools.base import BaseTool +from langchain.tools.spark_sql.prompt import QUERY_CHECKER + + +class BaseSparkSQLTool(BaseModel): + """Base tool for interacting with Spark SQL.""" + + db: SparkSQL = Field(exclude=True) + + # Override BaseTool.Config to appease mypy + # See https://github.com/pydantic/pydantic/issues/4173 + class Config(BaseTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.forbid + + +class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for querying a Spark SQL.""" + + name = "query_sql_db" + description = """ + Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL. + If the query is not correct, an error message will be returned. + If an error is returned, rewrite the query, check the query, and try again. + """ + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + return self.db.run_no_throw(query) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("QuerySqlDbTool does not support async") + + +class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for getting metadata about a Spark SQL.""" + + name = "schema_sql_db" + description = """ + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. + Be sure that the tables actually exist by calling list_tables_sql_db first! + + Example Input: "table1, table2, table3" + """ + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw(table_names.split(", ")) + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SchemaSqlDbTool does not support async") + + +class ListSparkSQLTool(BaseSparkSQLTool, BaseTool): + """Tool for getting tables names.""" + + name = "list_tables_sql_db" + description = "Input is an empty string, output is a comma separated list of tables in the Spark SQL." 
+ + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for a specific table.""" + return ", ".join(self.db.get_usable_table_names()) + + async def _arun( + self, + tool_input: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("ListTablesSqlDbTool does not support async") + + +class QueryCheckerTool(BaseSparkSQLTool, BaseTool): + """Use an LLM to check if a query is correct. + Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: LLMChain = Field(init=False) + name = "query_checker_sql_db" + description = """ + Use this tool to double check if your query is correct before executing it. + Always use this tool before executing a query with query_sql_db! + """ + + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, input_variables=["query"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["query"]: + raise ValueError( + "LLM chain for QueryCheckerTool need to use ['query'] as input_variables " + "for the embedded prompt" + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.predict(query=query) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.llm_chain.apredict(query=query) diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py index 1a0c4f43..bd45a8dc 100644 --- a/langchain/utilities/__init__.py +++ b/langchain/utilities/__init__.py @@ -16,6 +16,7 @@ from langchain.utilities.powerbi import PowerBIDataset from langchain.utilities.python import PythonREPL from langchain.utilities.searx_search import SearxSearchWrapper from langchain.utilities.serpapi import SerpAPIWrapper +from langchain.utilities.spark_sql import SparkSQL from langchain.utilities.wikipedia import WikipediaAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -38,5 +39,6 @@ __all__ = [ "PythonREPL", "LambdaWrapper", "PowerBIDataset", + "SparkSQL", "MetaphorSearchAPIWrapper", ] diff --git a/langchain/utilities/spark_sql.py b/langchain/utilities/spark_sql.py new file mode 100644 index 00000000..12cfbd2d --- /dev/null +++ b/langchain/utilities/spark_sql.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable, List, Optional + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, Row, SparkSession + + +class SparkSQL: + def __init__( + self, + spark_session: Optional[SparkSession] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 3, + ): + try: + from pyspark.sql import SparkSession + except ImportError: + raise ValueError( + "pyspark is not installed. 
Please install it with `pip install pyspark`" + ) + + self._spark = ( + spark_session if spark_session else SparkSession.builder.getOrCreate() + ) + if catalog is not None: + self._spark.catalog.setCurrentCatalog(catalog) + if schema is not None: + self._spark.catalog.setCurrentDatabase(schema) + + self._all_tables = set(self._get_all_table_names()) + self._include_tables = set(include_tables) if include_tables else set() + if self._include_tables: + missing_tables = self._include_tables - self._all_tables + if missing_tables: + raise ValueError( + f"include_tables {missing_tables} not found in database" + ) + self._ignore_tables = set(ignore_tables) if ignore_tables else set() + if self._ignore_tables: + missing_tables = self._ignore_tables - self._all_tables + if missing_tables: + raise ValueError( + f"ignore_tables {missing_tables} not found in database" + ) + usable_tables = self.get_usable_table_names() + self._usable_tables = set(usable_tables) if usable_tables else self._all_tables + + if not isinstance(sample_rows_in_table_info, int): + raise TypeError("sample_rows_in_table_info must be an integer") + + self._sample_rows_in_table_info = sample_rows_in_table_info + + @classmethod + def from_uri( + cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any + ) -> SparkSQL: + """Creating a remote Spark Session via Spark connect. + For example: SparkSQL.from_uri("sc://localhost:15002") + """ + try: + from pyspark.sql import SparkSession + except ImportError: + raise ValueError( + "pyspark is not installed. Please install it with `pip install pyspark`" + ) + + spark = SparkSession.builder.remote(database_uri).getOrCreate() + return cls(spark, **kwargs) + + def get_usable_table_names(self) -> Iterable[str]: + """Get names of tables available.""" + if self._include_tables: + return self._include_tables + # sorting the result can help LLM understanding it. + return sorted(self._all_tables - self._ignore_tables) + + def _get_all_table_names(self) -> Iterable[str]: + rows = self._spark.sql("SHOW TABLES").select("tableName").collect() + return list(map(lambda row: row.tableName, rows)) + + def _get_create_table_stmt(self, table: str) -> str: + statement = ( + self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt + ) + # Ignore the data source provider and options to reduce the number of tokens. 
+ using_clause_index = statement.find("USING") + return statement[:using_clause_index] + ";" + + def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + all_table_names = table_names + tables = [] + for table_name in all_table_names: + table_info = self._get_create_table_stmt(table_name) + if self._sample_rows_in_table_info: + table_info += "\n\n/*" + table_info += f"\n{self._get_sample_spark_rows(table_name)}\n" + table_info += "*/" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def _get_sample_spark_rows(self, table: str) -> str: + query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}" + df = self._spark.sql(query) + columns_str = "\t".join(list(map(lambda f: f.name, df.schema.fields))) + try: + sample_rows = self._get_dataframe_results(df) + # save the sample rows in string format + sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) + except Exception: + sample_rows_str = "" + + return ( + f"{self._sample_rows_in_table_info} rows from {table} table:\n" + f"{columns_str}\n" + f"{sample_rows_str}" + ) + + def _convert_row_as_tuple(self, row: Row) -> tuple: + return tuple(map(str, row.asDict().values())) + + def _get_dataframe_results(self, df: DataFrame) -> list: + return list(map(self._convert_row_as_tuple, df.collect())) + + def run(self, command: str, fetch: str = "all") -> str: + df = self._spark.sql(command) + if fetch == "one": + df = df.limit(1) + return str(self._get_dataframe_results(df)) + + def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables. + + Follows best practices as specified in: Rajkumar et al, 2022 + (https://arxiv.org/abs/2204.00498) + + If `sample_rows_in_table_info`, the specified number of sample rows will be + appended to each table description. This can increase performance as + demonstrated in the paper. + """ + try: + return self.get_table_info(table_names) + except ValueError as e: + """Format the error message""" + return f"Error: {e}" + + def run_no_throw(self, command: str, fetch: str = "all") -> str: + """Execute a SQL command and return a string representing the results. + + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. + + If the statement throws an error, the error message is returned. + """ + try: + from pyspark.errors import PySparkException + except ImportError: + raise ValueError( + "pyspark is not installed. Please install it with `pip install pyspark`" + ) + try: + return self.run(command, fetch) + except PySparkException as e: + """Format the error message""" + return f"Error: {e}" diff --git a/tests/unit_tests/agents/test_public_api.py b/tests/unit_tests/agents/test_public_api.py index 0b6f9df9..58eb111e 100644 --- a/tests/unit_tests/agents/test_public_api.py +++ b/tests/unit_tests/agents/test_public_api.py @@ -24,6 +24,7 @@ _EXPECTED = [ "create_pbi_agent", "create_pbi_chat_agent", "create_spark_dataframe_agent", + "create_spark_sql_agent", "create_sql_agent", "create_vectorstore_agent", "create_vectorstore_router_agent",
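
The notebook above mentions Spark Connect only in a passing comment, so here is a minimal end-to-end sketch of how the pieces introduced in this patch fit together against a remote cluster. The `sc://localhost:15002` endpoint, the `langchain_example` schema, the example question, and an `OPENAI_API_KEY` in the environment are assumptions for illustration, not part of the patch:

```python
# Hypothetical end-to-end sketch using the APIs added in this patch.
# Assumes `pip install pyspark langchain openai`, a Spark Connect server
# reachable at sc://localhost:15002, and OPENAI_API_KEY set in the environment.
from langchain.agents import create_spark_sql_agent
from langchain.agents.agent_toolkits import SparkSQLToolkit
from langchain.chat_models import ChatOpenAI
from langchain.utilities.spark_sql import SparkSQL

# Connect to a remote cluster via Spark Connect instead of a local session.
spark_sql = SparkSQL.from_uri("sc://localhost:15002", schema="langchain_example")

llm = ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

# Ask a question against whatever tables exist in the chosen schema.
agent_executor.run("How many passengers survived, grouped by passenger class?")
```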
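The `SparkSQL` wrapper can also be exercised on its own to inspect exactly what the toolkit will expose to the agent. A small sketch follows, assuming the `titanic` table created in the notebook already exists in the current schema; the constructor arguments are the ones defined in this patch, while the schema and table names are taken from the notebook:

```python
# Minimal sketch of the SparkSQL utility added in this patch, used directly.
# Assumes the `titanic` table from the notebook exists in the `langchain_example` schema.
from langchain.utilities.spark_sql import SparkSQL

# Restrict the wrapper to a single table and include 3 sample rows in the
# generated table info (3 is also the constructor's default).
db = SparkSQL(
    schema="langchain_example",
    include_tables=["titanic"],
    sample_rows_in_table_info=3,
)

print(db.get_usable_table_names())      # only the included table is exposed
print(db.get_table_info(["titanic"]))   # CREATE TABLE statement plus sample rows
print(db.run("SELECT COUNT(*) FROM titanic"))
```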