Harrison/spark connect example (#4659)

Co-authored-by: Mike Wang <62768671+skcoirz@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-05-13 21:44:54 -07:00 committed by GitHub
parent 2747ccbcf1
commit 9aa9fe7021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 215 additions and 30 deletions

View File

@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
"metadata": {
"tags": []
@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"id": "cc422f53-c51c-4694-a834-72ecd1e68363",
"metadata": {
"tags": []
@ -206,9 +206,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "LangChain",
"language": "python",
"name": "python3"
"name": "langchain"
},
"language_info": {
"codemirror_mode": {
@ -220,7 +220,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.9.16"
}
},
"nbformat": 4,

View File

@ -6,26 +6,26 @@
"source": [
"# Spark Dataframe Agent\n",
"\n",
"This notebook shows how to use agents to interact with a Spark dataframe. It is mostly optimized for question answering.\n",
"This notebook shows how to use agents to interact with a Spark dataframe and Spark Connect. It is mostly optimized for question answering.\n",
"\n",
"**NOTE: this agent calls the Python agent under the hood, which executes LLM generated Python code - this can be bad if the LLM generated Python code is harmful. Use cautiously.**"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import create_spark_dataframe_agent\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"...input_your_openai_api_key...\""
"os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -92,7 +92,7 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
"Action: python_repl_ast\n",
"Action Input: df.count()\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@ -108,7 +108,7 @@
"'There are 891 rows in the dataframe.'"
]
},
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -145,7 +145,7 @@
"'30 people have more than 3 siblings.'"
]
},
"execution_count": 12,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -156,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -194,7 +194,7 @@
"'5.449689683556195'"
]
},
"execution_count": 13,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -202,13 +202,183 @@
"source": [
"agent.run(\"whats the square root of the average age?\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"spark.stop()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Spark Connect Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# in apache-spark root directory. (tested here with \"spark-3.4.0-bin-hadoop3 and later\")\n",
"# To launch Spark with support for Spark Connect sessions, run the start-connect-server.sh script.\n",
"!./sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.4.0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"23/05/08 10:06:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
]
}
],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"# Now that the Spark server is running, we can connect to it remotely using Spark Connect. We do this by \n",
"# creating a remote Spark session on the client where our application runs. Before we can do that, we need \n",
"# to make sure to stop the existing regular Spark session because it cannot coexist with the remote \n",
"# Spark Connect session we are about to create.\n",
"SparkSession.builder.master(\"local[*]\").getOrCreate().stop()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# The command we used above to launch the server configured Spark to run as localhost:15002. \n",
"# So now we can create a remote Spark session on the client using the following command.\n",
"spark = SparkSession.builder.remote(\"sc://localhost:15002\").getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n",
"| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n",
"| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n",
"| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S|\n",
"| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S|\n",
"| 6| 0| 3| Moran, Mr. James| male|null| 0| 0| 330877| 8.4583| null| Q|\n",
"| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S|\n",
"| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S|\n",
"| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S|\n",
"| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C|\n",
"| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S|\n",
"| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S|\n",
"| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S|\n",
"| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S|\n",
"| 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S|\n",
"| 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S|\n",
"| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q|\n",
"| 18| 1| 2|Williams, Mr. Cha...| male|null| 0| 0| 244373| 13.0| null| S|\n",
"| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S|\n",
"| 20| 1| 3|Masselmani, Mrs. ...|female|null| 0| 0| 2649| 7.225| null| C|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"csv_file_path = \"titanic.csv\"\n",
"df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n",
"df.show()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import create_spark_dataframe_agent\n",
"from langchain.llms import OpenAI\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\"\n",
"\n",
"agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Thought: I need to find the row with the highest fare\n",
"Action: python_repl_ast\n",
"Action Input: df.sort(df.Fare.desc()).first()\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mRow(PassengerId=259, Survived=1, Pclass=1, Name='Ward, Miss. Anna', Sex='female', Age=35.0, SibSp=0, Parch=0, Ticket='PC 17755', Fare=512.3292, Cabin=None, Embarked='C')\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the name of the person who bought the most expensive ticket\n",
"Final Answer: Miss. Anna Ward\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Miss. Anna Ward'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"\"\"\n",
"who bought the most expensive ticket?\n",
"You can find all supported function types in https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/dataframe.html\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"spark.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "LangChain",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "langchain"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -220,9 +390,8 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2

View File

@ -10,6 +10,28 @@ from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool
def _validate_spark_df(df: Any) -> bool:
try:
from pyspark.sql import DataFrame as SparkLocalDataFrame
if not isinstance(df, SparkLocalDataFrame):
return False
return True
except ImportError:
return False
def _validate_spark_connect_df(df: Any) -> bool:
try:
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
if not isinstance(df, SparkConnectDataFrame):
return False
return True
except ImportError:
return False
def create_spark_dataframe_agent(
llm: BaseLLM,
df: Any,
@ -26,15 +48,9 @@ def create_spark_dataframe_agent(
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a spark agent from an LLM and dataframe."""
try:
from pyspark.sql import DataFrame
except ImportError:
raise ValueError(
"spark package not found, please install with `pip install pyspark`"
)
if not isinstance(df, DataFrame):
raise ValueError(f"Expected Spark Data Frame object, got {type(df)}")
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
raise ValueError("Spark is not installed. run `pip install pyspark`.")
if input_variables is None:
input_variables = ["df", "input", "agent_scratchpad"]