From 67db495fcf621e426e38abc4015357e7620143a6 Mon Sep 17 00:00:00 2001 From: Mike Wang <62768671+skcoirz@users.noreply.github.com> Date: Wed, 3 May 2023 22:45:23 -0700 Subject: [PATCH] [agent] Add Spark Agent (#4020) - added support for spark through pyspark library. - added jupyter notebook as example. --- .../agents/toolkits/examples/spark.ipynb | 229 ++++++++++++++++++ langchain/agents/__init__.py | 2 + langchain/agents/agent_toolkits/__init__.py | 2 + .../agents/agent_toolkits/spark/__init__.py | 1 + langchain/agents/agent_toolkits/spark/base.py | 68 ++++++ .../agents/agent_toolkits/spark/prompt.py | 13 + tests/unit_tests/agents/test_public_api.py | 1 + 7 files changed, 316 insertions(+) create mode 100644 docs/modules/agents/toolkits/examples/spark.ipynb create mode 100644 langchain/agents/agent_toolkits/spark/__init__.py create mode 100644 langchain/agents/agent_toolkits/spark/base.py create mode 100644 langchain/agents/agent_toolkits/spark/prompt.py diff --git a/docs/modules/agents/toolkits/examples/spark.ipynb b/docs/modules/agents/toolkits/examples/spark.ipynb new file mode 100644 index 00000000..c3ea9010 --- /dev/null +++ b/docs/modules/agents/toolkits/examples/spark.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Spark Dataframe Agent\n", + "\n", + "This notebook shows how to use agents to interact with a Spark dataframe. It is mostly optimized for question answering.\n", + "\n", + "**NOTE: this agent calls the Python agent under the hood, which executes LLM generated Python code - this can be bad if the LLM generated Python code is harmful. 
Use cautiously.**" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import create_spark_dataframe_agent\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"...input_your_openai_api_key...\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n", + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n", + "| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n", + "| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n", + "| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S|\n", + "| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S|\n", + "| 6| 0| 3| Moran, Mr. James| male|null| 0| 0| 330877| 8.4583| null| Q|\n", + "| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S|\n", + "| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S|\n", + "| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S|\n", + "| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C|\n", + "| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S|\n", + "| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S|\n", + "| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S|\n", + "| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S|\n", + "| 15| 0| 3|Vestrom, Miss. 
Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S|\n", + "| 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S|\n", + "| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q|\n", + "| 18| 1| 2|Williams, Mr. Cha...| male|null| 0| 0| 244373| 13.0| null| S|\n", + "| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S|\n", + "| 20| 1| 3|Masselmani, Mrs. ...|female|null| 0| 0| 2649| 7.225| null| C|\n", + "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n", + "only showing top 20 rows\n", + "\n" + ] + } + ], + "source": [ + "from langchain.llms import OpenAI\n", + "from pyspark.sql import SparkSession\n", + "\n", + "spark = SparkSession.builder.getOrCreate()\n", + "csv_file_path = \"titanic.csv\"\n", + "df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n", + "df.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n", + "Action: python_repl_ast\n", + "Action Input: df.count()\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: There are 891 rows in the dataframe.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'There are 891 rows in the dataframe.'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"how many rows are there?\")" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThought: I need to find out how many people have more than 3 siblings\n", + "Action: python_repl_ast\n", + "Action Input: df.filter(df.SibSp > 3).count()\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 30 people have more than 3 siblings.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'30 people have more than 3 siblings.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"how many people have more than 3 sibligngs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThought: I need to get the average age first\n", + "Action: python_repl_ast\n", + "Action Input: df.agg({\"Age\": \"mean\"}).collect()[0][0]\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now have the average age, I need to get the square root\n", + "Action: python_repl_ast\n", + "Action Input: math.sqrt(29.69911764705882)\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mname 'math' is not defined\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to import math first\n", + "Action: python_repl_ast\n", + "Action Input: import math\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now have the math library imported, I can get the square root\n", + "Action: python_repl_ast\n", 
+ "Action Input: math.sqrt(29.69911764705882)\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 5.449689683556195\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'5.449689683556195'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"whats the square root of the average age?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "LangChain", + "language": "python", + "name": "langchain" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index 5081d048..b70dea94 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -14,6 +14,7 @@ from langchain.agents.agent_toolkits import ( create_pandas_dataframe_agent, create_pbi_agent, create_pbi_chat_agent, + create_spark_dataframe_agent, create_sql_agent, create_vectorstore_agent, create_vectorstore_router_agent, @@ -53,6 +54,7 @@ __all__ = [ "create_pandas_dataframe_agent", "create_pbi_agent", "create_pbi_chat_agent", + "create_spark_dataframe_agent", "create_sql_agent", "create_vectorstore_agent", "create_vectorstore_router_agent", diff --git a/langchain/agents/agent_toolkits/__init__.py b/langchain/agents/agent_toolkits/__init__.py index 4a0bbd1c..95239754 100644 --- a/langchain/agents/agent_toolkits/__init__.py +++ b/langchain/agents/agent_toolkits/__init__.py @@ -16,6 +16,7 @@ from langchain.agents.agent_toolkits.powerbi.base import create_pbi_agent from 
"""Agent for working with Spark DataFrame objects."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool


def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and a Spark DataFrame.

    The agent answers questions about ``df`` by generating Python snippets and
    executing them in a ``PythonAstREPLTool`` whose local namespace contains the
    dataframe as ``df``.  NOTE: this executes LLM-generated Python code, so it
    should only be used with trusted models/inputs.

    Args:
        llm: Language model used to drive the zero-shot agent.
        df: A ``pyspark.sql.DataFrame`` to query.  Typed ``Any`` so pyspark is
            only imported lazily inside this function.
        callback_manager: Optional callback manager passed to the chain, agent
            and executor.
        prefix: Prompt text placed before the tool descriptions.
        suffix: Prompt text placed after the tool descriptions; must reference
            the ``{df}`` variable, which is filled with ``str(df.first())``.
        input_variables: Variables expected by the prompt template.  Defaults
            to ``["df", "input", "agent_scratchpad"]``.
        verbose: Whether the executor logs its intermediate steps.
        return_intermediate_steps: Whether the executor returns the
            (action, observation) trace alongside the final answer.
        max_iterations: Cap on agent steps before forced termination.
        max_execution_time: Wall-clock cap (seconds) on a run, if set.
        early_stopping_method: Strategy when a cap is hit ("force" or
            "generate").
        agent_executor_kwargs: Extra keyword arguments for
            ``AgentExecutor.from_agent_and_tools``.
        **kwargs: Extra keyword arguments forwarded to ``ZeroShotAgent``.

    Returns:
        An ``AgentExecutor`` wired with a Python REPL tool over ``df``.

    Raises:
        ValueError: If pyspark is not installed, or ``df`` is not a Spark
            DataFrame.
    """
    # Lazy import so langchain does not hard-depend on pyspark.
    try:
        from pyspark.sql import DataFrame
    except ImportError:
        raise ValueError(
            "spark package not found, please install with `pip install pyspark`"
        )

    if not isinstance(df, DataFrame):
        raise ValueError(f"Expected Spark Data Frame object, got {type(df)}")

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]
    # The REPL tool exposes the dataframe to generated code as `df`.
    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    # Show the model one sample row (first()) rather than the whole dataframe,
    # which may be arbitrarily large.
    partial_prompt = prompt.partial(df=str(df.first()))
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
+Question: {input} +{agent_scratchpad}""" diff --git a/tests/unit_tests/agents/test_public_api.py b/tests/unit_tests/agents/test_public_api.py index 489adcfe..bc6fd4f3 100644 --- a/tests/unit_tests/agents/test_public_api.py +++ b/tests/unit_tests/agents/test_public_api.py @@ -23,6 +23,7 @@ _EXPECTED = [ "create_pandas_dataframe_agent", "create_pbi_agent", "create_pbi_chat_agent", + "create_spark_dataframe_agent", "create_sql_agent", "create_vectorstore_agent", "create_vectorstore_router_agent",