[nit] Simplify Spark Creation Validation Check A Little Bit (#4761)

- Simplify the validation check a little bit (a short before/after sketch follows below).
- Re-tested in a Jupyter notebook.
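In short, the change collapses an `if not isinstance(...): return False` / `return True` branch into a single `return isinstance(...)`. A minimal sketch of the pattern, taken from the diff below (the `_before` name is illustrative only):

```python
from typing import Any


# Before: a redundant branch around the isinstance check.
def _validate_spark_df_before(df: Any) -> bool:
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        if not isinstance(df, SparkLocalDataFrame):
            return False
        return True
    except ImportError:
        return False


# After: return the isinstance result directly.
def _validate_spark_df(df: Any) -> bool:
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        return isinstance(df, SparkLocalDataFrame)
    except ImportError:
        return False
```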

Reviewer: @hwchase17
Mike Wang committed 1 year ago (via GitHub)
parent e027a38f33
commit db6f7ed0ba

@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -17,7 +18,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.agents import create_spark_dataframe_agent\n",
     "import os\n",
     "\n",
     "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
@@ -25,9 +25,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "23/05/15 20:33:10 WARN Utils: Your hostname, Mikes-Mac-mini.local resolves to a loopback address: 127.0.0.1; using 192.168.68.115 instead (on interface en1)\n",
+      "23/05/15 20:33:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+      "Setting default log level to \"WARN\".\n",
+      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+      "23/05/15 20:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -64,6 +75,7 @@
    "source": [
     "from langchain.llms import OpenAI\n",
     "from pyspark.sql import SparkSession\n",
+    "from langchain.agents import create_spark_dataframe_agent\n",
     "\n",
     "spark = SparkSession.builder.getOrCreate()\n",
     "csv_file_path = \"titanic.csv\"\n",
@@ -92,7 +104,7 @@
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-     "\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
+     "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
      "Action: python_repl_ast\n",
      "Action Input: df.count()\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@@ -205,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -213,6 +225,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [

@@ -14,9 +14,7 @@ def _validate_spark_df(df: Any) -> bool:
     try:
         from pyspark.sql import DataFrame as SparkLocalDataFrame
-        if not isinstance(df, SparkLocalDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkLocalDataFrame)
     except ImportError:
         return False
@@ -25,9 +23,7 @@ def _validate_spark_connect_df(df: Any) -> bool:
     try:
         from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
-        if not isinstance(df, SparkConnectDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkConnectDataFrame)
     except ImportError:
         return False
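Behavior is unchanged: each helper returns True only for the matching DataFrame type and False otherwise, including when pyspark is not installed. A rough usage sketch, assuming a local pyspark install; the helper's import path is a guess and should be adjusted to the actual module:

```python
from pyspark.sql import SparkSession

# Hypothetical import path for the private helper; adjust to the real module.
from langchain.agents.agent_toolkits.spark.base import _validate_spark_df

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

assert _validate_spark_df(df)               # a real Spark DataFrame passes
assert not _validate_spark_df("not a df")   # anything else fails
# If pyspark were not installed, the ImportError branch would return False.
```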
