mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
230 lines
8.5 KiB
Plaintext
230 lines
8.5 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Spark Dataframe Agent\n",
|
||
|
"\n",
|
||
|
"This notebook shows how to use agents to interact with a Spark dataframe. It is mostly optimized for question answering.\n",
|
||
|
"\n",
|
||
|
"**NOTE: this agent calls the Python agent under the hood, which executes LLM generated Python code - this can be bad if the LLM generated Python code is harmful. Use cautiously.**"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain.agents import create_spark_dataframe_agent\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"os.environ[\"OPENAI_API_KEY\"] = \"...input_your_openai_api_key...\""
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
|
||
|
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n",
|
||
|
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
|
||
|
"| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n",
|
||
|
"| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n",
|
||
|
"| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n",
|
||
|
"| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S|\n",
|
||
|
"| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S|\n",
|
||
|
"| 6| 0| 3| Moran, Mr. James| male|null| 0| 0| 330877| 8.4583| null| Q|\n",
|
||
|
"| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S|\n",
|
||
|
"| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S|\n",
|
||
|
"| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S|\n",
|
||
|
"| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C|\n",
|
||
|
"| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S|\n",
|
||
|
"| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S|\n",
|
||
|
"| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S|\n",
|
||
|
"| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S|\n",
|
||
|
"| 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S|\n",
|
||
|
"| 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S|\n",
|
||
|
"| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q|\n",
|
||
|
"| 18| 1| 2|Williams, Mr. Cha...| male|null| 0| 0| 244373| 13.0| null| S|\n",
|
||
|
"| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S|\n",
|
||
|
"| 20| 1| 3|Masselmani, Mrs. ...|female|null| 0| 0| 2649| 7.225| null| C|\n",
|
||
|
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
|
||
|
"only showing top 20 rows\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from langchain.llms import OpenAI\n",
|
||
|
"from pyspark.sql import SparkSession\n",
|
||
|
"\n",
|
||
|
"spark = SparkSession.builder.getOrCreate()\n",
|
||
|
"csv_file_path = \"titanic.csv\"\n",
|
||
|
"df = spark.read.csv(csv_file_path, header=True, inferSchema=True)\n",
|
||
|
"df.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: df.count()\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||
|
"Final Answer: There are 891 rows in the dataframe.\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'There are 891 rows in the dataframe.'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run(\"how many rows are there?\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3mThought: I need to find out how many people have more than 3 siblings\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: df.filter(df.SibSp > 3).count()\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||
|
"Final Answer: 30 people have more than 3 siblings.\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'30 people have more than 3 siblings.'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run(\"how many people have more than 3 sibligngs\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3mThought: I need to get the average age first\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: df.agg({\"Age\": \"mean\"}).collect()[0][0]\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I now have the average age, I need to get the square root\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: math.sqrt(29.69911764705882)\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3mname 'math' is not defined\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I need to import math first\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: import math\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I now have the math library imported, I can get the square root\n",
|
||
|
"Action: python_repl_ast\n",
|
||
|
"Action Input: math.sqrt(29.69911764705882)\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||
|
"Final Answer: 5.449689683556195\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'5.449689683556195'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run(\"whats the square root of the average age?\")"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "LangChain",
|
||
|
"language": "python",
|
||
|
"name": "langchain"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.16"
|
||
|
},
|
||
|
"orig_nbformat": 4
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|