[nit] Simplify Spark Creation Validation Check A Little Bit (#4761)

- simplify the validation check a little bit (sketched below).
- re-tested in a Jupyter notebook.
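
A minimal before/after sketch of the simplification (the full diff is below):

```python
# before: spell out both branches of the isinstance check
if not isinstance(df, SparkLocalDataFrame):
    return False
return True

# after: return the result of the isinstance check directly
return isinstance(df, SparkLocalDataFrame)
```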

Reviewer: @hwchase17
Mike Wang 1 year ago committed by GitHub
parent e027a38f33
commit db6f7ed0ba

@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -17,7 +18,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.agents import create_spark_dataframe_agent\n",
     "import os\n",
     "\n",
     "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
@@ -25,9 +25,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "23/05/15 20:33:10 WARN Utils: Your hostname, Mikes-Mac-mini.local resolves to a loopback address: 127.0.0.1; using 192.168.68.115 instead (on interface en1)\n",
+      "23/05/15 20:33:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+      "Setting default log level to \"WARN\".\n",
+      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+      "23/05/15 20:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -64,6 +75,7 @@
    "source": [
     "from langchain.llms import OpenAI\n",
     "from pyspark.sql import SparkSession\n",
+    "from langchain.agents import create_spark_dataframe_agent\n",
     "\n",
     "spark = SparkSession.builder.getOrCreate()\n",
     "csv_file_path = \"titanic.csv\"\n",
@@ -92,7 +104,7 @@
     "\n",
     "\n",
     "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-    "\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
+    "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
     "Action: python_repl_ast\n",
     "Action Input: df.count()\u001b[0m\n",
     "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@@ -205,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -213,6 +225,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [

@@ -14,9 +14,7 @@ def _validate_spark_df(df: Any) -> bool:
     try:
         from pyspark.sql import DataFrame as SparkLocalDataFrame
-        if not isinstance(df, SparkLocalDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkLocalDataFrame)
     except ImportError:
         return False
@@ -25,9 +23,7 @@ def _validate_spark_connect_df(df: Any) -> bool:
     try:
         from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
-        if not isinstance(df, SparkConnectDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkConnectDataFrame)
     except ImportError:
         return False
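
Taken together, the two hunks leave the validators as below: a self-contained sketch of the resulting functions, with a small sanity check at the end (the `__main__` demo is illustrative only and assumes a local pyspark install):

```python
from typing import Any


def _validate_spark_df(df: Any) -> bool:
    """Return True only if df is a classic pyspark.sql DataFrame."""
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        return isinstance(df, SparkLocalDataFrame)
    except ImportError:
        # pyspark is not installed, so df cannot be a Spark DataFrame.
        return False


def _validate_spark_connect_df(df: Any) -> bool:
    """Return True only if df is a Spark Connect DataFrame (pyspark >= 3.4)."""
    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

        return isinstance(df, SparkConnectDataFrame)
    except ImportError:
        return False


if __name__ == "__main__":
    # Illustrative check only: a real DataFrame passes, anything else fails cleanly.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "a")], ["id", "val"])
    assert _validate_spark_df(df)
    assert not _validate_spark_df("not a dataframe")
```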
