diff --git a/docs/modules/agents/toolkits/examples/spark.ipynb b/docs/modules/agents/toolkits/examples/spark.ipynb
index 0dc09240..6b462a90 100644
--- a/docs/modules/agents/toolkits/examples/spark.ipynb
+++ b/docs/modules/agents/toolkits/examples/spark.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -17,7 +18,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.agents import create_spark_dataframe_agent\n",
     "import os\n",
     "\n",
     "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
@@ -25,9 +25,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "23/05/15 20:33:10 WARN Utils: Your hostname, Mikes-Mac-mini.local resolves to a loopback address: 127.0.0.1; using 192.168.68.115 instead (on interface en1)\n",
+      "23/05/15 20:33:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+      "Setting default log level to \"WARN\".\n",
+      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+      "23/05/15 20:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -64,6 +75,7 @@
    "source": [
     "from langchain.llms import OpenAI\n",
     "from pyspark.sql import SparkSession\n",
+    "from langchain.agents import create_spark_dataframe_agent\n",
     "\n",
     "spark = SparkSession.builder.getOrCreate()\n",
     "csv_file_path = \"titanic.csv\"\n",
@@ -92,7 +104,7 @@
     "\n",
     "\n",
     "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-    "\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
+    "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
     "Action: python_repl_ast\n",
     "Action Input: df.count()\u001b[0m\n",
     "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@@ -205,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -213,6 +225,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
diff --git a/langchain/agents/agent_toolkits/spark/base.py b/langchain/agents/agent_toolkits/spark/base.py
index 1b91dc48..135600ad 100644
--- a/langchain/agents/agent_toolkits/spark/base.py
+++ b/langchain/agents/agent_toolkits/spark/base.py
@@ -14,9 +14,7 @@ def _validate_spark_df(df: Any) -> bool:
     try:
         from pyspark.sql import DataFrame as SparkLocalDataFrame

-        if not isinstance(df, SparkLocalDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkLocalDataFrame)
     except ImportError:
         return False

@@ -25,9 +23,7 @@ def _validate_spark_connect_df(df: Any) -> bool:
     try:
         from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

-        if not isinstance(df, SparkConnectDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkConnectDataFrame)
     except ImportError:
         return False
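
Note on the base.py change: both helpers collapse an "if not isinstance(...): return False / return True" branch into a direct "return isinstance(...)", which is the idiomatic form since isinstance() already yields the bool being computed. For reference, a minimal sketch of how the helpers read after this patch, with an illustrative usage; the __main__ section is an assumption (it is not part of the patch) and presumes a local pyspark install with the notebook's titanic.csv on hand.

    from typing import Any


    def _validate_spark_df(df: Any) -> bool:
        """True if df is a classic (local) Spark DataFrame."""
        try:
            from pyspark.sql import DataFrame as SparkLocalDataFrame

            # isinstance() already produces the bool we need; return it directly.
            return isinstance(df, SparkLocalDataFrame)
        except ImportError:
            # pyspark is an optional dependency; without it, nothing
            # can be a Spark DataFrame.
            return False


    def _validate_spark_connect_df(df: Any) -> bool:
        """True if df is a Spark Connect DataFrame."""
        try:
            from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

            return isinstance(df, SparkConnectDataFrame)
        except ImportError:
            return False


    if __name__ == "__main__":
        # Illustrative usage only (not part of the patch): requires pyspark
        # and a local titanic.csv, as in the example notebook.
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        df = spark.read.csv("titanic.csv", header=True, inferSchema=True)
        print(_validate_spark_df(df))          # True
        print(_validate_spark_df("not a df"))  # False
        spark.stop()

The ImportError guard is the design point worth keeping: it lets the toolkit probe for either DataFrame flavor without making pyspark a hard dependency of langchain itself.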