From 9aa9fe7021b4d0635e3d9ce1a5d485649fe15395 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sat, 13 May 2023 21:44:54 -0700
Subject: [PATCH] Harrison/spark connect example (#4659)

Co-authored-by: Mike Wang <62768671+skcoirz@users.noreply.github.com>
---
 .../agents/toolkits/examples/python.ipynb     |  10 +-
 .../agents/toolkits/examples/spark.ipynb      | 203 ++++++++++++++++--
 langchain/agents/agent_toolkits/spark/base.py |  32 ++-
 3 files changed, 215 insertions(+), 30 deletions(-)

diff --git a/docs/modules/agents/toolkits/examples/python.ipynb b/docs/modules/agents/toolkits/examples/python.ipynb
index 08128ea2..1c05a1f9 100644
--- a/docs/modules/agents/toolkits/examples/python.ipynb
+++ b/docs/modules/agents/toolkits/examples/python.ipynb
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 4,
    "id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
    "metadata": {
     "tags": []
@@ -27,7 +27,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 5,
    "id": "cc422f53-c51c-4694-a834-72ecd1e68363",
    "metadata": {
     "tags": []
@@ -206,9 +206,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "LangChain",
    "language": "python",
-   "name": "python3"
+   "name": "langchain"
   },
   "language_info": {
    "codemirror_mode": {
@@ -220,7 +220,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.9.16"
   }
  },
  "nbformat": 4,
diff --git a/docs/modules/agents/toolkits/examples/spark.ipynb b/docs/modules/agents/toolkits/examples/spark.ipynb
index 8874826d..0dc09240 100644
--- a/docs/modules/agents/toolkits/examples/spark.ipynb
+++ b/docs/modules/agents/toolkits/examples/spark.ipynb
@@ -6,26 +6,26 @@
    "source": [
     "# Spark Dataframe Agent\n",
     "\n",
-    "This notebook shows how to use agents to interact with a Spark dataframe. It is mostly optimized for question answering.\n",
+    "This notebook shows how to use agents to interact with a Spark dataframe and Spark Connect. It is mostly optimized for question answering.\n",
     "\n",
     "**NOTE: this agent calls the Python agent under the hood, which executes LLM generated Python code - this can be bad if the LLM generated Python code is harmful. Use cautiously.**"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
     "from langchain.agents import create_spark_dataframe_agent\n",
     "import os\n",
     "\n",
-    "os.environ[\"OPENAI_API_KEY\"] = \"...input_your_openai_api_key...\""
+    "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -73,7 +73,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -82,7 +82,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -92,7 +92,7 @@
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-     "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
+     "\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
      "Action: python_repl_ast\n",
      "Action Input: df.count()\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@@ -108,7 +108,7 @@
       "'There are 891 rows in the dataframe.'"
      ]
     },
-    "execution_count": 17,
+    "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -119,7 +119,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -145,7 +145,7 @@
       "'30 people have more than 3 siblings.'"
      ]
     },
-    "execution_count": 12,
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -156,7 +156,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
@@ -194,7 +194,7 @@
       "'5.449689683556195'"
      ]
     },
-    "execution_count": 13,
+    "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -202,13 +202,183 @@
    "source": [
     "agent.run(\"whats the square root of the average age?\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "spark.stop()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Spark Connect Example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# In the Apache Spark root directory (tested with \"spark-3.4.0-bin-hadoop3\" and later).\n",
+    "# To launch Spark with support for Spark Connect sessions, run the start-connect-server.sh script.\n",
+    "!./sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.4.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "23/05/08 10:06:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pyspark.sql import SparkSession\n",
+    "\n",
+    "# Now that the Spark server is running, we can connect to it remotely using Spark Connect. We do this by\n",
+    "# creating a remote Spark session on the client where our application runs. Before we can do that, we need\n",
+    "# to make sure to stop the existing regular Spark session because it cannot coexist with the remote\n",
+    "# Spark Connect session we are about to create.\n",
+    "SparkSession.builder.master(\"local[*]\").getOrCreate().stop()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The command we used above to launch the server configured Spark to listen on localhost:15002.\n",
+    "# So now we can create a remote Spark session on the client with the following command.\n",
+    "spark = SparkSession.builder.remote(\"sc://localhost:15002\").getOrCreate()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
+      "|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|\n",
+      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
+      "|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|\n",
+      "|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|\n",
+      "|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|\n",
+      "|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|\n",
+      "|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|\n",
+      "|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583| null|       Q|\n",
+      "|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|  E46|       S|\n",
+      "|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075| null|       S|\n",
+      "|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333| null|       S|\n",
+      "|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708| null|       C|\n",
+      "|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|   G6|       S|\n",
+      "|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55| C103|       S|\n",
+      "|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05| null|       S|\n",
+      "|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275| null|       S|\n",
+      "|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542| null|       S|\n",
+      "|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0| null|       S|\n",
+      "|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125| null|       Q|\n",
+      "|         18|       1|     2|Williams, Mr. Cha...|  male|null|    0|    0|          244373|   13.0| null|       S|\n",
+      "|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0| null|       S|\n",
+      "|         20|       1|     3|Masselmani, Mrs. ...|female|null|    0|    0|            2649|  7.225| null|       C|\n",
+      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
+      "only showing top 20 rows\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "csv_file_path = \"titanic.csv\"\n",
+    "df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n",
+    "df.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.agents import create_spark_dataframe_agent\n",
+    "from langchain.llms import OpenAI\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\"\n",
+    "\n",
+    "agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m\n",
+      "Thought: I need to find the row with the highest fare\n",
+      "Action: python_repl_ast\n",
+      "Action Input: df.sort(df.Fare.desc()).first()\u001b[0m\n",
+      "Observation: \u001b[36;1m\u001b[1;3mRow(PassengerId=259, Survived=1, Pclass=1, Name='Ward, Miss. Anna', Sex='female', Age=35.0, SibSp=0, Parch=0, Ticket='PC 17755', Fare=512.3292, Cabin=None, Embarked='C')\u001b[0m\n",
+      "Thought:\u001b[32;1m\u001b[1;3m I now know the name of the person who bought the most expensive ticket\n",
+      "Final Answer: Miss. Anna Ward\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Miss. Anna Ward'"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "agent.run(\"\"\"\n",
+    "who bought the most expensive ticket?\n",
+    "You can find all supported function types in https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/dataframe.html\n",
+    "\"\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "spark.stop()"
+   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "LangChain",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
-  "name": "langchain"
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
@@ -220,9 +390,8 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.16"
- },
- "orig_nbformat": 4
+  "version": "3.9.1"
+ }
 },
 "nbformat": 4,
 "nbformat_minor": 2
diff --git a/langchain/agents/agent_toolkits/spark/base.py b/langchain/agents/agent_toolkits/spark/base.py
index b789adf7..1b91dc48 100644
--- a/langchain/agents/agent_toolkits/spark/base.py
+++ b/langchain/agents/agent_toolkits/spark/base.py
@@ -10,6 +10,28 @@ from langchain.llms.base import BaseLLM
 from langchain.tools.python.tool import PythonAstREPLTool
 
 
+def _validate_spark_df(df: Any) -> bool:
+    try:
+        from pyspark.sql import DataFrame as SparkLocalDataFrame
+
+        if not isinstance(df, SparkLocalDataFrame):
+            return False
+        return True
+    except ImportError:
+        return False
+
+
+def _validate_spark_connect_df(df: Any) -> bool:
+    try:
+        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
+
+        if not isinstance(df, SparkConnectDataFrame):
+            return False
+        return True
+    except ImportError:
+        return False
+
+
 def create_spark_dataframe_agent(
     llm: BaseLLM,
     df: Any,
@@ -26,15 +48,9 @@ def create_spark_dataframe_agent(
     **kwargs: Dict[str, Any],
 ) -> AgentExecutor:
     """Construct a spark agent from an LLM and dataframe."""
-    try:
-        from pyspark.sql import DataFrame
-    except ImportError:
-        raise ValueError(
-            "spark package not found, please install with `pip install pyspark`"
-        )
-
-    if not isinstance(df, DataFrame):
-        raise ValueError(f"Expected Spark Data Frame object, got {type(df)}")
+    if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
+        raise ValueError(f"Expected a Spark DataFrame, got {type(df)}. Is pyspark installed?")
 
     if input_variables is None:
         input_variables = ["df", "input", "agent_scratchpad"]
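
For reference, a minimal end-to-end sketch of the Spark Connect path this patch enables. This is illustrative only: it assumes a Spark Connect server is already running on localhost:15002 (started with `sbin/start-connect-server.sh` as in the notebook), a local `titanic.csv`, and an `OPENAI_API_KEY` set in the environment.

    from pyspark.sql import SparkSession

    from langchain.agents import create_spark_dataframe_agent
    from langchain.llms import OpenAI

    # Connect to the running Spark Connect server instead of starting a local session.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Reading the CSV through the remote session yields a Spark Connect DataFrame.
    df = spark.read.csv("titanic.csv", header=True, inferSchema=True)

    # _validate_spark_connect_df lets this dataframe type pass the agent's type check.
    agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)
    agent.run("how many rows are there?")

    spark.stop()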
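And a quick sanity check of the two new validators. The helpers are private, so this is only a sketch (assuming pyspark 3.4+ is installed): it builds a classic local DataFrame and confirms that only the matching validator accepts it.

    from pyspark.sql import SparkSession

    from langchain.agents.agent_toolkits.spark.base import (
        _validate_spark_connect_df,
        _validate_spark_df,
    )

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

    assert _validate_spark_df(df)  # a classic pyspark.sql.DataFrame
    assert not _validate_spark_connect_df(df)  # not a Spark Connect DataFrame

    spark.stop()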