forked from Archives/langchain
Harrison/use functions agent (#6185)
Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
This commit is contained in:
parent
7d2b946d0b
commit
e82687ddf4
@ -30,38 +30,83 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.llms import OpenAI"
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||||||
|
"from langchain.agents.agent_types import AgentType"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bd806175",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "a1717204",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent = create_csv_agent(\n",
|
||||||
|
" OpenAI(temperature=0), \n",
|
||||||
|
" 'titanic.csv', \n",
|
||||||
|
" verbose=True, \n",
|
||||||
|
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c31bb8a6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using OpenAI Functions\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "16c4dc59",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent = create_csv_agent(\n",
|
||||||
|
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
|
||||||
|
" 'titanic.csv', \n",
|
||||||
|
" verbose=True, \n",
|
||||||
|
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
"id": "16c4dc59",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"agent = create_csv_agent(OpenAI(temperature=0), 'titanic.csv', verbose=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "46b9489d",
|
"id": "46b9489d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error in on_chain_start callback: 'name'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Invoking: `python_repl_ast` with `df.shape[0]`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: df.shape[0]\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
|
||||||
"Final Answer: There are 891 rows.\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -69,10 +114,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'There are 891 rows.'"
|
"'There are 891 rows in the dataframe.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -83,23 +128,28 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"id": "a96309be",
|
"id": "a96309be",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"scrolled": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error in on_chain_start callback: 'name'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Invoking: `python_repl_ast` with `df[df['SibSp'] > 3]['PassengerId'].count()`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m30\u001b[0m\u001b[32;1m\u001b[1;3mThere are 30 people in the dataframe who have more than 3 siblings.\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of people with more than 3 siblings\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: df[df['SibSp'] > 3].shape[0]\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
|
||||||
"Final Answer: 30 people have more than 3 siblings.\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -107,10 +157,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'30 people have more than 3 siblings.'"
|
"'There are 30 people in the dataframe who have more than 3 siblings.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -121,35 +171,39 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 6,
|
||||||
"id": "964a09f7",
|
"id": "964a09f7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error in on_chain_start callback: 'name'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Invoking: `python_repl_ast` with `import pandas as pd\n",
|
||||||
|
"import math\n",
|
||||||
|
"\n",
|
||||||
|
"# Create a dataframe\n",
|
||||||
|
"data = {'Age': [22, 38, 26, 35, 35]}\n",
|
||||||
|
"df = pd.DataFrame(data)\n",
|
||||||
|
"\n",
|
||||||
|
"# Calculate the average age\n",
|
||||||
|
"average_age = df['Age'].mean()\n",
|
||||||
|
"\n",
|
||||||
|
"# Calculate the square root of the average age\n",
|
||||||
|
"square_root = math.sqrt(average_age)\n",
|
||||||
|
"\n",
|
||||||
|
"square_root`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m5.585696017507576\u001b[0m\u001b[32;1m\u001b[1;3mThe square root of the average age is approximately 5.59.\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mThought: I need to calculate the average age first\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: df['Age'].mean()\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3mNameError(\"name 'math' is not defined\")\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I need to import the math library\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: import math\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
|
||||||
"Final Answer: 5.449689683556195\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -157,10 +211,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'5.449689683556195'"
|
"'The square root of the average age is approximately 5.59.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -181,23 +235,26 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 8,
|
||||||
"id": "15f11fbd",
|
"id": "15f11fbd",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error in on_chain_start callback: 'name'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Invoking: `python_repl_ast` with `df1['Age'].nunique() - df2['Age'].nunique()`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m-1\u001b[0m\u001b[32;1m\u001b[1;3mThere is 1 row in the age column that is different between the two dataframes.\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mThought: I need to compare the age columns in both dataframes\n",
|
|
||||||
"Action: python_repl_ast\n",
|
|
||||||
"Action Input: len(df1[df1['Age'] != df2['Age']])\u001b[0m\n",
|
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m177\u001b[0m\n",
|
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
|
||||||
"Final Answer: 177 rows in the age column are different.\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -205,17 +262,17 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'177 rows in the age column are different.'"
|
"'There is 1 row in the age column that is different between the two dataframes.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"agent = create_csv_agent(OpenAI(temperature=0), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True)\n",
|
"agent = create_csv_agent(ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS)\n",
|
||||||
"agent.run(\"how many rows in the age column are different?\")"
|
"agent.run(\"how many rows in the age column are different between the two dfs?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -19,7 +19,9 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.agents import create_pandas_dataframe_agent"
|
"from langchain.agents import create_pandas_dataframe_agent\n",
|
||||||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||||||
|
"from langchain.agents.agent_types import AgentType"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -35,6 +37,16 @@
|
|||||||
"df = pd.read_csv('titanic.csv')"
|
"df = pd.read_csv('titanic.csv')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a62858e2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
@ -45,9 +57,34 @@
|
|||||||
"agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)"
|
"agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "7233ab56",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using OpenAI Functions\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
|
"id": "a8ea710e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent = create_pandas_dataframe_agent(\n",
|
||||||
|
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
|
||||||
|
" df, \n",
|
||||||
|
" verbose=True, \n",
|
||||||
|
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
"id": "a9207a2e",
|
"id": "a9207a2e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -57,13 +94,12 @@
|
|||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n",
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
"Action: python_repl_ast\n",
|
"Invoking: `python_repl_ast` with `df.shape[0]`\n",
|
||||||
"Action Input: df.shape[0]\u001b[0m\n",
|
"\n",
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
|
"\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
|
||||||
"Final Answer: There are 891 rows.\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -71,10 +107,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'There are 891 rows.'"
|
"'There are 891 rows in the dataframe.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -172,7 +208,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "c4bc0584",
|
"id": "c4bc0584",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -257,7 +292,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.16"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 1,
|
||||||
"id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
|
"id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -22,12 +22,24 @@
|
|||||||
"from langchain.agents.agent_toolkits import create_python_agent\n",
|
"from langchain.agents.agent_toolkits import create_python_agent\n",
|
||||||
"from langchain.tools.python.tool import PythonREPLTool\n",
|
"from langchain.tools.python.tool import PythonREPLTool\n",
|
||||||
"from langchain.python import PythonREPL\n",
|
"from langchain.python import PythonREPL\n",
|
||||||
"from langchain.llms.openai import OpenAI"
|
"from langchain.llms.openai import OpenAI\n",
|
||||||
|
"from langchain.agents.agent_types import AgentType\n",
|
||||||
|
"from langchain.chat_models import ChatOpenAI"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ca30d64c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 2,
|
||||||
"id": "cc422f53-c51c-4694-a834-72ecd1e68363",
|
"id": "cc422f53-c51c-4694-a834-72ecd1e68363",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -37,7 +49,34 @@
|
|||||||
"agent_executor = create_python_agent(\n",
|
"agent_executor = create_python_agent(\n",
|
||||||
" llm=OpenAI(temperature=0, max_tokens=1000),\n",
|
" llm=OpenAI(temperature=0, max_tokens=1000),\n",
|
||||||
" tool=PythonREPLTool(),\n",
|
" tool=PythonREPLTool(),\n",
|
||||||
" verbose=True\n",
|
" verbose=True,\n",
|
||||||
|
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bb487e8e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using OpenAI Functions\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "6e651822",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent_executor = create_python_agent(\n",
|
||||||
|
" llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
|
||||||
|
" tool=PythonREPLTool(),\n",
|
||||||
|
" verbose=True,\n",
|
||||||
|
" agent_type=AgentType.OPENAI_FUNCTIONS,\n",
|
||||||
|
" agent_executor_kwargs={\"handle_parsing_errors\": True},\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -52,7 +91,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe",
|
"id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -64,23 +103,20 @@
|
|||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3m I need to calculate the 10th fibonacci number\n",
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
"Action: Python REPL\n",
|
"Invoking: `Python_REPL` with `def fibonacci(n):\n",
|
||||||
"Action Input: def fibonacci(n):\n",
|
" if n <= 0:\n",
|
||||||
" if n == 0:\n",
|
|
||||||
" return 0\n",
|
" return 0\n",
|
||||||
" elif n == 1:\n",
|
" elif n == 1:\n",
|
||||||
" return 1\n",
|
" return 1\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" return fibonacci(n-1) + fibonacci(n-2)\u001b[0m\n",
|
" return fibonacci(n-1) + fibonacci(n-2)\n",
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
|
"\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I need to call the function with 10 as the argument\n",
|
"fibonacci(10)`\n",
|
||||||
"Action: Python REPL\n",
|
"\n",
|
||||||
"Action Input: fibonacci(10)\u001b[0m\n",
|
"\n",
|
||||||
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe 10th Fibonacci number is 55.\u001b[0m\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
|
||||||
"Final Answer: 55\u001b[0m\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -88,10 +124,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'55'"
|
"'The 10th Fibonacci number is 55.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -111,9 +147,11 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"id": "4b9f60e7-eb6a-4f14-8604-498d863d4482",
|
"id": "4b9f60e7-eb6a-4f14-8604-498d863d4482",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"scrolled": false
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
@ -121,59 +159,70 @@
|
|||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3m I need to write a neural network in PyTorch and train it on the given data.\n",
|
"\u001b[32;1m\u001b[1;3mCould not parse tool input: {'name': 'python', 'arguments': 'import torch\\nimport torch.nn as nn\\nimport torch.optim as optim\\n\\n# Define the neural network\\nclass SingleNeuron(nn.Module):\\n def __init__(self):\\n super(SingleNeuron, self).__init__()\\n self.linear = nn.Linear(1, 1)\\n \\n def forward(self, x):\\n return self.linear(x)\\n\\n# Create the synthetic data\\nx_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\\ny_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\\n\\n# Create the neural network\\nmodel = SingleNeuron()\\n\\n# Define the loss function and optimizer\\ncriterion = nn.MSELoss()\\noptimizer = optim.SGD(model.parameters(), lr=0.01)\\n\\n# Train the neural network\\nfor epoch in range(1, 1001):\\n # Forward pass\\n y_pred = model(x_train)\\n \\n # Compute loss\\n loss = criterion(y_pred, y_train)\\n \\n # Backward pass and optimization\\n optimizer.zero_grad()\\n loss.backward()\\n optimizer.step()\\n \\n # Print the loss every 100 epochs\\n if epoch % 100 == 0:\\n print(f\"Epoch {epoch}: Loss = {loss.item()}\")\\n\\n# Make a prediction for x = 5\\nx_test = torch.tensor([[5.0]], dtype=torch.float32)\\ny_pred = model(x_test)\\ny_pred.item()'} because the `arguments` is not valid JSON.\u001b[0mInvalid or incomplete response\u001b[32;1m\u001b[1;3m\n",
|
||||||
"Action: Python REPL\n",
|
"Invoking: `Python_REPL` with `import torch\n",
|
||||||
"Action Input: \n",
|
"import torch.nn as nn\n",
|
||||||
"import torch\n",
|
"import torch.optim as optim\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Define the model\n",
|
"# Define the neural network\n",
|
||||||
"model = torch.nn.Sequential(\n",
|
"class SingleNeuron(nn.Module):\n",
|
||||||
" torch.nn.Linear(1, 1)\n",
|
" def __init__(self):\n",
|
||||||
")\n",
|
" super(SingleNeuron, self).__init__()\n",
|
||||||
|
" self.linear = nn.Linear(1, 1)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" return self.linear(x)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Define the loss\n",
|
"# Create the synthetic data\n",
|
||||||
"loss_fn = torch.nn.MSELoss()\n",
|
"x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\n",
|
||||||
|
"y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Define the optimizer\n",
|
"# Create the neural network\n",
|
||||||
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
|
"model = SingleNeuron()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Define the data\n",
|
"# Define the loss function and optimizer\n",
|
||||||
"x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n",
|
"criterion = nn.MSELoss()\n",
|
||||||
"y_data = torch.tensor([[2.0], [4.0], [6.0], [8.0]])\n",
|
"optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Train the model\n",
|
"# Train the neural network\n",
|
||||||
"for epoch in range(1000):\n",
|
"for epoch in range(1, 1001):\n",
|
||||||
" # Forward pass\n",
|
" # Forward pass\n",
|
||||||
" y_pred = model(x_data)\n",
|
" y_pred = model(x_train)\n",
|
||||||
"\n",
|
" \n",
|
||||||
" # Compute and print loss\n",
|
" # Compute loss\n",
|
||||||
" loss = loss_fn(y_pred, y_data)\n",
|
" loss = criterion(y_pred, y_train)\n",
|
||||||
" if (epoch+1) % 100 == 0:\n",
|
" \n",
|
||||||
" print(f'Epoch {epoch+1}: loss = {loss.item():.4f}')\n",
|
" # Backward pass and optimization\n",
|
||||||
"\n",
|
|
||||||
" # Zero the gradients\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
" optimizer.zero_grad()\n",
|
||||||
"\n",
|
|
||||||
" # Backward pass\n",
|
|
||||||
" loss.backward()\n",
|
" loss.backward()\n",
|
||||||
"\n",
|
|
||||||
" # Update the weights\n",
|
|
||||||
" optimizer.step()\n",
|
" optimizer.step()\n",
|
||||||
"\u001b[0m\n",
|
" \n",
|
||||||
"Observation: \u001b[36;1m\u001b[1;3mEpoch 100: loss = 0.0013\n",
|
" # Print the loss every 100 epochs\n",
|
||||||
"Epoch 200: loss = 0.0007\n",
|
" if epoch % 100 == 0:\n",
|
||||||
"Epoch 300: loss = 0.0004\n",
|
" print(f\"Epoch {epoch}: Loss = {loss.item()}\")\n",
|
||||||
"Epoch 400: loss = 0.0002\n",
|
"\n",
|
||||||
"Epoch 500: loss = 0.0001\n",
|
"# Make a prediction for x = 5\n",
|
||||||
"Epoch 600: loss = 0.0001\n",
|
"x_test = torch.tensor([[5.0]], dtype=torch.float32)\n",
|
||||||
"Epoch 700: loss = 0.0000\n",
|
"y_pred = model(x_test)\n",
|
||||||
"Epoch 800: loss = 0.0000\n",
|
"y_pred.item()`\n",
|
||||||
"Epoch 900: loss = 0.0000\n",
|
"\n",
|
||||||
"Epoch 1000: loss = 0.0000\n",
|
"\n",
|
||||||
"\u001b[0m\n",
|
"\u001b[0m\u001b[36;1m\u001b[1;3mEpoch 100: Loss = 0.03825576975941658\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
"Epoch 200: Loss = 0.02100197970867157\n",
|
||||||
"Final Answer: The prediction for x = 5 is 10.0.\u001b[0m\n",
|
"Epoch 300: Loss = 0.01152981910854578\n",
|
||||||
|
"Epoch 400: Loss = 0.006329738534986973\n",
|
||||||
|
"Epoch 500: Loss = 0.0034749575424939394\n",
|
||||||
|
"Epoch 600: Loss = 0.0019077073084190488\n",
|
||||||
|
"Epoch 700: Loss = 0.001047312980517745\n",
|
||||||
|
"Epoch 800: Loss = 0.0005749554838985205\n",
|
||||||
|
"Epoch 900: Loss = 0.0003156439634039998\n",
|
||||||
|
"Epoch 1000: Loss = 0.00017328384274151176\n",
|
||||||
|
"\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Invoking: `Python_REPL` with `x_test.item()`\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe prediction for x = 5 is 10.000173568725586.\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -181,10 +230,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'The prediction for x = 5 is 10.0.'"
|
"'The prediction for x = 5 is 10.000173568725586.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -206,9 +255,9 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "LangChain",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "langchain"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
@ -220,7 +269,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.16"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -37,7 +37,30 @@
|
|||||||
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
|
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
|
||||||
"from langchain.sql_database import SQLDatabase\n",
|
"from langchain.sql_database import SQLDatabase\n",
|
||||||
"from langchain.llms.openai import OpenAI\n",
|
"from langchain.llms.openai import OpenAI\n",
|
||||||
"from langchain.agents import AgentExecutor"
|
"from langchain.agents import AgentExecutor\n",
|
||||||
|
"from langchain.agents.agent_types import AgentType\n",
|
||||||
|
"from langchain.chat_models import ChatOpenAI"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "65ec5bb3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
|
||||||
|
"toolkit = SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "f74d1792",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -49,16 +72,39 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
|
|
||||||
"toolkit = SQLDatabaseToolkit(db=db)\n",
|
|
||||||
"\n",
|
|
||||||
"agent_executor = create_sql_agent(\n",
|
"agent_executor = create_sql_agent(\n",
|
||||||
" llm=OpenAI(temperature=0),\n",
|
" llm=OpenAI(temperature=0),\n",
|
||||||
" toolkit=toolkit,\n",
|
" toolkit=toolkit,\n",
|
||||||
" verbose=True\n",
|
" verbose=True,\n",
|
||||||
|
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "971cc455",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using OpenAI Functions\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "6426a27d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# agent_executor = create_sql_agent(\n",
|
||||||
|
"# llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
|
||||||
|
"# toolkit=toolkit,\n",
|
||||||
|
"# verbose=True,\n",
|
||||||
|
"# agent_type=AgentType.OPENAI_FUNCTIONS\n",
|
||||||
|
"# )"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b",
|
"id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b",
|
||||||
@ -69,7 +115,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7",
|
"id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -81,14 +127,16 @@
|
|||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
|
"\u001b[32;1m\u001b[1;3m\n",
|
||||||
"Action Input: \"\"\u001b[0m\n",
|
"Invoking: `list_tables_sql_db` with `{}`\n",
|
||||||
"Observation: \u001b[38;5;200m\u001b[1;3mArtist, Invoice, Playlist, Genre, Album, PlaylistTrack, Track, InvoiceLine, MediaType, Employee, Customer\u001b[0m\n",
|
"\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the playlisttrack table\n",
|
"\n",
|
||||||
"Action: schema_sql_db\n",
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Track, PlaylistTrack, InvoiceLine, sales_table, Playlist, Genre, Employee, Customer, Invoice, MediaType\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||||
"Action Input: \"PlaylistTrack\"\u001b[0m\n",
|
"Invoking: `schema_sql_db` with `PlaylistTrack`\n",
|
||||||
"Observation: \u001b[33;1m\u001b[1;3m\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
||||||
"CREATE TABLE \"PlaylistTrack\" (\n",
|
"CREATE TABLE \"PlaylistTrack\" (\n",
|
||||||
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
||||||
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
||||||
@ -97,13 +145,36 @@
|
|||||||
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
|
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n",
|
"/*\n",
|
||||||
"PlaylistId TrackId\n",
|
"3 rows from PlaylistTrack table:\n",
|
||||||
"1 3402\n",
|
"PlaylistId\tTrackId\n",
|
||||||
"1 3389\n",
|
"1\t3402\n",
|
||||||
"1 3390\u001b[0m\n",
|
"1\t3389\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
"1\t3390\n",
|
||||||
"Final Answer: The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.\u001b[0m\n",
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \n",
|
||||||
|
"\n",
|
||||||
|
"Here is the schema of the `PlaylistTrack` table:\n",
|
||||||
|
"\n",
|
||||||
|
"```\n",
|
||||||
|
"CREATE TABLE \"PlaylistTrack\" (\n",
|
||||||
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
||||||
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
||||||
|
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
|
||||||
|
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
|
||||||
|
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
|
||||||
|
")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Here are three sample rows from the `PlaylistTrack` table:\n",
|
||||||
|
"\n",
|
||||||
|
"```\n",
|
||||||
|
"PlaylistId TrackId\n",
|
||||||
|
"1 3402\n",
|
||||||
|
"1 3389\n",
|
||||||
|
"1 3390\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Please let me know if there is anything else I can help you with.\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -111,10 +182,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.'"
|
"'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help you with.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -519,7 +590,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.9"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,19 +1,26 @@
|
|||||||
"""Agent for working with pandas objects."""
|
"""Agent for working with pandas objects."""
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||||
|
FUNCTIONS_WITH_DF,
|
||||||
|
FUNCTIONS_WITH_MULTI_DF,
|
||||||
MULTI_DF_PREFIX,
|
MULTI_DF_PREFIX,
|
||||||
|
MULTI_DF_PREFIX_FUNCTIONS,
|
||||||
PREFIX,
|
PREFIX,
|
||||||
|
PREFIX_FUNCTIONS,
|
||||||
SUFFIX_NO_DF,
|
SUFFIX_NO_DF,
|
||||||
SUFFIX_WITH_DF,
|
SUFFIX_WITH_DF,
|
||||||
SUFFIX_WITH_MULTI_DF,
|
SUFFIX_WITH_MULTI_DF,
|
||||||
)
|
)
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
|
from langchain.agents.types import AgentType
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
from langchain.schema import SystemMessage
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool
|
from langchain.tools.python.tool import PythonAstREPLTool
|
||||||
|
|
||||||
|
|
||||||
@ -137,9 +144,108 @@ def _get_prompt_and_tools(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_single_prompt(
|
||||||
|
df: Any,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
|
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||||
|
if suffix is not None:
|
||||||
|
suffix_to_use = suffix
|
||||||
|
if include_df_in_prompt:
|
||||||
|
suffix_to_use = suffix_to_use.format(df_head=str(df.head().to_markdown()))
|
||||||
|
elif include_df_in_prompt:
|
||||||
|
suffix_to_use = FUNCTIONS_WITH_DF.format(df_head=str(df.head().to_markdown()))
|
||||||
|
else:
|
||||||
|
suffix_to_use = ""
|
||||||
|
|
||||||
|
if prefix is None:
|
||||||
|
prefix = PREFIX_FUNCTIONS
|
||||||
|
|
||||||
|
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||||
|
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||||
|
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||||
|
return prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_multi_prompt(
|
||||||
|
dfs: Any,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
|
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||||
|
if suffix is not None:
|
||||||
|
suffix_to_use = suffix
|
||||||
|
if include_df_in_prompt:
|
||||||
|
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
|
||||||
|
suffix_to_use = suffix_to_use.format(
|
||||||
|
dfs_head=dfs_head,
|
||||||
|
)
|
||||||
|
elif include_df_in_prompt:
|
||||||
|
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
|
||||||
|
suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
|
||||||
|
dfs_head=dfs_head,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
suffix_to_use = ""
|
||||||
|
|
||||||
|
if prefix is None:
|
||||||
|
prefix = MULTI_DF_PREFIX_FUNCTIONS
|
||||||
|
prefix = prefix.format(num_dfs=str(len(dfs)))
|
||||||
|
|
||||||
|
df_locals = {}
|
||||||
|
for i, dataframe in enumerate(dfs):
|
||||||
|
df_locals[f"df{i + 1}"] = dataframe
|
||||||
|
tools = [PythonAstREPLTool(locals=df_locals)]
|
||||||
|
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||||
|
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||||
|
return prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_prompt_and_tools(
|
||||||
|
df: Any,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
|
input_variables: Optional[List[str]] = None,
|
||||||
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
|
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"pandas package not found, please install with `pip install pandas`"
|
||||||
|
)
|
||||||
|
if input_variables is not None:
|
||||||
|
raise ValueError("`input_variables` is not supported at the moment.")
|
||||||
|
|
||||||
|
if include_df_in_prompt is not None and suffix is not None:
|
||||||
|
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
||||||
|
|
||||||
|
if isinstance(df, list):
|
||||||
|
for item in df:
|
||||||
|
if not isinstance(item, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_functions_multi_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not isinstance(df, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_functions_single_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_pandas_dataframe_agent(
|
def create_pandas_dataframe_agent(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
df: Any,
|
df: Any,
|
||||||
|
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
suffix: Optional[str] = None,
|
suffix: Optional[str] = None,
|
||||||
@ -154,26 +260,44 @@ def create_pandas_dataframe_agent(
|
|||||||
**kwargs: Dict[str, Any],
|
**kwargs: Dict[str, Any],
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
"""Construct a pandas agent from an LLM and dataframe."""
|
"""Construct a pandas agent from an LLM and dataframe."""
|
||||||
|
agent: BaseSingleActionAgent
|
||||||
prompt, tools = _get_prompt_and_tools(
|
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||||
df,
|
prompt, tools = _get_prompt_and_tools(
|
||||||
prefix=prefix,
|
df,
|
||||||
suffix=suffix,
|
prefix=prefix,
|
||||||
input_variables=input_variables,
|
suffix=suffix,
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
input_variables=input_variables,
|
||||||
)
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
llm_chain = LLMChain(
|
)
|
||||||
llm=llm,
|
llm_chain = LLMChain(
|
||||||
prompt=prompt,
|
llm=llm,
|
||||||
callback_manager=callback_manager,
|
prompt=prompt,
|
||||||
)
|
callback_manager=callback_manager,
|
||||||
tool_names = [tool.name for tool in tools]
|
)
|
||||||
agent = ZeroShotAgent(
|
tool_names = [tool.name for tool in tools]
|
||||||
llm_chain=llm_chain,
|
agent = ZeroShotAgent(
|
||||||
allowed_tools=tool_names,
|
llm_chain=llm_chain,
|
||||||
callback_manager=callback_manager,
|
allowed_tools=tool_names,
|
||||||
**kwargs,
|
callback_manager=callback_manager,
|
||||||
)
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||||
|
_prompt, tools = _get_functions_prompt_and_tools(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
input_variables=input_variables,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
)
|
||||||
|
agent = OpenAIFunctionsAgent(
|
||||||
|
llm=llm,
|
||||||
|
prompt=_prompt,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||||
return AgentExecutor.from_agent_and_tools(
|
return AgentExecutor.from_agent_and_tools(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -28,3 +28,17 @@ This is the result of `print(df.head())` for each dataframe:
|
|||||||
Begin!
|
Begin!
|
||||||
Question: {input}
|
Question: {input}
|
||||||
{agent_scratchpad}"""
|
{agent_scratchpad}"""
|
||||||
|
|
||||||
|
PREFIX_FUNCTIONS = """
|
||||||
|
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
|
||||||
|
|
||||||
|
MULTI_DF_PREFIX_FUNCTIONS = """
|
||||||
|
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
|
||||||
|
|
||||||
|
FUNCTIONS_WITH_DF = """
|
||||||
|
This is the result of `print(df.head())`:
|
||||||
|
{df_head}"""
|
||||||
|
|
||||||
|
FUNCTIONS_WITH_MULTI_DF = """
|
||||||
|
This is the result of `print(df.head())` for each dataframe:
|
||||||
|
{dfs_head}"""
|
||||||
|
@ -2,18 +2,22 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||||
from langchain.agents.agent_toolkits.python.prompt import PREFIX
|
from langchain.agents.agent_toolkits.python.prompt import PREFIX
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
|
from langchain.agents.types import AgentType
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.schema import SystemMessage
|
||||||
from langchain.tools.python.tool import PythonREPLTool
|
from langchain.tools.python.tool import PythonREPLTool
|
||||||
|
|
||||||
|
|
||||||
def create_python_agent(
|
def create_python_agent(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
tool: PythonREPLTool,
|
tool: PythonREPLTool,
|
||||||
|
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
prefix: str = PREFIX,
|
prefix: str = PREFIX,
|
||||||
@ -22,14 +26,29 @@ def create_python_agent(
|
|||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
"""Construct a python agent from an LLM and tool."""
|
"""Construct a python agent from an LLM and tool."""
|
||||||
tools = [tool]
|
tools = [tool]
|
||||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
agent: BaseSingleActionAgent
|
||||||
llm_chain = LLMChain(
|
|
||||||
llm=llm,
|
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||||
prompt=prompt,
|
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||||
callback_manager=callback_manager,
|
llm_chain = LLMChain(
|
||||||
)
|
llm=llm,
|
||||||
tool_names = [tool.name for tool in tools]
|
prompt=prompt,
|
||||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
callback_manager=callback_manager,
|
||||||
|
)
|
||||||
|
tool_names = [tool.name for tool in tools]
|
||||||
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||||
|
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||||
|
system_message = SystemMessage(content=prefix)
|
||||||
|
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||||
|
agent = OpenAIFunctionsAgent(
|
||||||
|
llm=llm,
|
||||||
|
prompt=_prompt,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||||
return AgentExecutor.from_agent_and_tools(
|
return AgentExecutor.from_agent_and_tools(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -1,22 +1,35 @@
|
|||||||
"""SQL agent."""
|
"""SQL agent."""
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||||
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
|
from langchain.agents.agent_toolkits.sql.prompt import (
|
||||||
|
SQL_FUNCTIONS_SUFFIX,
|
||||||
|
SQL_PREFIX,
|
||||||
|
SQL_SUFFIX,
|
||||||
|
)
|
||||||
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||||
|
from langchain.agents.agent_types import AgentType
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
MessagesPlaceholder,
|
||||||
|
)
|
||||||
|
from langchain.schema import AIMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
def create_sql_agent(
|
def create_sql_agent(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
toolkit: SQLDatabaseToolkit,
|
toolkit: SQLDatabaseToolkit,
|
||||||
|
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
prefix: str = SQL_PREFIX,
|
prefix: str = SQL_PREFIX,
|
||||||
suffix: str = SQL_SUFFIX,
|
suffix: Optional[str] = None,
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
@ -30,20 +43,44 @@ def create_sql_agent(
|
|||||||
"""Construct a sql agent from an LLM and tools."""
|
"""Construct a sql agent from an LLM and tools."""
|
||||||
tools = toolkit.get_tools()
|
tools = toolkit.get_tools()
|
||||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||||
prompt = ZeroShotAgent.create_prompt(
|
agent: BaseSingleActionAgent
|
||||||
tools,
|
|
||||||
prefix=prefix,
|
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||||
suffix=suffix,
|
prompt = ZeroShotAgent.create_prompt(
|
||||||
format_instructions=format_instructions,
|
tools,
|
||||||
input_variables=input_variables,
|
prefix=prefix,
|
||||||
)
|
suffix=suffix or SQL_SUFFIX,
|
||||||
llm_chain = LLMChain(
|
format_instructions=format_instructions,
|
||||||
llm=llm,
|
input_variables=input_variables,
|
||||||
prompt=prompt,
|
)
|
||||||
callback_manager=callback_manager,
|
llm_chain = LLMChain(
|
||||||
)
|
llm=llm,
|
||||||
tool_names = [tool.name for tool in tools]
|
prompt=prompt,
|
||||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
callback_manager=callback_manager,
|
||||||
|
)
|
||||||
|
tool_names = [tool.name for tool in tools]
|
||||||
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||||
|
|
||||||
|
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=prefix),
|
||||||
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
|
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||||
|
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||||
|
]
|
||||||
|
input_variables = ["input", "agent_scratchpad"]
|
||||||
|
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||||
|
|
||||||
|
agent = OpenAIFunctionsAgent(
|
||||||
|
llm=llm,
|
||||||
|
prompt=_prompt,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||||
|
|
||||||
return AgentExecutor.from_agent_and_tools(
|
return AgentExecutor.from_agent_and_tools(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -19,3 +19,5 @@ SQL_SUFFIX = """Begin!
|
|||||||
Question: {input}
|
Question: {input}
|
||||||
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
|
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
|
||||||
{agent_scratchpad}"""
|
{agent_scratchpad}"""
|
||||||
|
|
||||||
|
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
|
||||||
|
@ -7,12 +7,11 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
|||||||
from langchain.agents import BaseSingleActionAgent
|
from langchain.agents import BaseSingleActionAgent
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import Callbacks
|
||||||
Callbacks,
|
|
||||||
)
|
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.prompts.chat import (
|
from langchain.prompts.chat import (
|
||||||
|
BaseMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
@ -23,6 +22,7 @@ from langchain.schema import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
|
OutputParserException,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
@ -34,7 +34,9 @@ class _FunctionsAgentAction(AgentAction):
|
|||||||
message_log: List[BaseMessage]
|
message_log: List[BaseMessage]
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]:
|
def _convert_agent_action_to_messages(
|
||||||
|
agent_action: AgentAction, observation: str
|
||||||
|
) -> List[BaseMessage]:
|
||||||
"""Convert an agent action to a message.
|
"""Convert an agent action to a message.
|
||||||
|
|
||||||
This code is used to reconstruct the original AI message from the agent action.
|
This code is used to reconstruct the original AI message from the agent action.
|
||||||
@ -45,9 +47,12 @@ def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMes
|
|||||||
Returns:
|
Returns:
|
||||||
AIMessage that corresponds to the original tool invocation.
|
AIMessage that corresponds to the original tool invocation.
|
||||||
"""
|
"""
|
||||||
if not isinstance(agent_action, _FunctionsAgentAction):
|
if isinstance(agent_action, _FunctionsAgentAction):
|
||||||
raise ValueError("This agent type only works with _FunctionsAgentAction")
|
return agent_action.message_log + [
|
||||||
return agent_action.message_log
|
_create_function_message(agent_action, observation)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [AIMessage(content=agent_action.log)]
|
||||||
|
|
||||||
|
|
||||||
def _create_function_message(
|
def _create_function_message(
|
||||||
@ -61,7 +66,10 @@ def _create_function_message(
|
|||||||
FunctionMessage that corresponds to the original tool invocation
|
FunctionMessage that corresponds to the original tool invocation
|
||||||
"""
|
"""
|
||||||
if not isinstance(observation, str):
|
if not isinstance(observation, str):
|
||||||
content = json.dumps(observation)
|
try:
|
||||||
|
content = json.dumps(observation)
|
||||||
|
except Exception:
|
||||||
|
content = str(observation)
|
||||||
else:
|
else:
|
||||||
content = observation
|
content = observation
|
||||||
return FunctionMessage(
|
return FunctionMessage(
|
||||||
@ -83,8 +91,7 @@ def _format_intermediate_steps(
|
|||||||
|
|
||||||
for intermediate_step in intermediate_steps:
|
for intermediate_step in intermediate_steps:
|
||||||
agent_action, observation = intermediate_step
|
agent_action, observation = intermediate_step
|
||||||
messages.extend(_convert_agent_action_to_messages(agent_action))
|
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||||
messages.append(_create_function_message(agent_action, observation))
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -102,7 +109,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
|
|||||||
try:
|
try:
|
||||||
_tool_input = json.loads(function_call["arguments"])
|
_tool_input = json.loads(function_call["arguments"])
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
raise ValueError(
|
raise OutputParserException(
|
||||||
f"Could not parse tool input: {function_call} because "
|
f"Could not parse tool input: {function_call} because "
|
||||||
f"the `arguments` is not valid JSON."
|
f"the `arguments` is not valid JSON."
|
||||||
)
|
)
|
||||||
@ -202,12 +209,24 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
return agent_decision
|
return agent_decision
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_prompt(cls) -> BasePromptTemplate:
|
def create_prompt(
|
||||||
messages = [
|
cls,
|
||||||
SystemMessage(content="You are a helpful AI assistant."),
|
system_message: Optional[SystemMessage] = SystemMessage(
|
||||||
HumanMessagePromptTemplate.from_template("{input}"),
|
content="You are a helpful AI assistant."
|
||||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
),
|
||||||
]
|
) -> BasePromptTemplate:
|
||||||
|
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
|
if system_message:
|
||||||
|
messages = [system_message]
|
||||||
|
else:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
messages.extend(
|
||||||
|
[
|
||||||
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
|
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||||
|
]
|
||||||
|
)
|
||||||
input_variables = ["input", "agent_scratchpad"]
|
input_variables = ["input", "agent_scratchpad"]
|
||||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
|
|||||||
class ArxivQueryRun(BaseTool):
|
class ArxivQueryRun(BaseTool):
|
||||||
"""Tool that adds the capability to search using the Arxiv API."""
|
"""Tool that adds the capability to search using the Arxiv API."""
|
||||||
|
|
||||||
name = "Arxiv"
|
name = "arxiv"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Arxiv.org "
|
"A wrapper around Arxiv.org "
|
||||||
"Useful for when you need to answer questions about Physics, Mathematics, "
|
"Useful for when you need to answer questions about Physics, Mathematics, "
|
||||||
|
@ -27,7 +27,7 @@ class AzureCogsFormRecognizerTool(BaseTool):
|
|||||||
azure_cogs_endpoint: str = "" #: :meta private:
|
azure_cogs_endpoint: str = "" #: :meta private:
|
||||||
doc_analysis_client: Any #: :meta private:
|
doc_analysis_client: Any #: :meta private:
|
||||||
|
|
||||||
name = "Azure Cognitive Services Form Recognizer"
|
name = "azure_cognitive_services_form_recognizer"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Azure Cognitive Services Form Recognizer. "
|
"A wrapper around Azure Cognitive Services Form Recognizer. "
|
||||||
"Useful for when you need to "
|
"Useful for when you need to "
|
||||||
|
@ -28,7 +28,7 @@ class AzureCogsImageAnalysisTool(BaseTool):
|
|||||||
vision_service: Any #: :meta private:
|
vision_service: Any #: :meta private:
|
||||||
analysis_options: Any #: :meta private:
|
analysis_options: Any #: :meta private:
|
||||||
|
|
||||||
name = "Azure Cognitive Services Image Analysis"
|
name = "azure_cognitive_services_image_analysis"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Azure Cognitive Services Image Analysis. "
|
"A wrapper around Azure Cognitive Services Image Analysis. "
|
||||||
"Useful for when you need to analyze images. "
|
"Useful for when you need to analyze images. "
|
||||||
|
@ -32,7 +32,7 @@ class AzureCogsSpeech2TextTool(BaseTool):
|
|||||||
speech_language: str = "en-US" #: :meta private:
|
speech_language: str = "en-US" #: :meta private:
|
||||||
speech_config: Any #: :meta private:
|
speech_config: Any #: :meta private:
|
||||||
|
|
||||||
name = "Azure Cognitive Services Speech2Text"
|
name = "azure_cognitive_services_speech2text"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Azure Cognitive Services Speech2Text. "
|
"A wrapper around Azure Cognitive Services Speech2Text. "
|
||||||
"Useful for when you need to transcribe audio to text. "
|
"Useful for when you need to transcribe audio to text. "
|
||||||
|
@ -28,7 +28,7 @@ class AzureCogsText2SpeechTool(BaseTool):
|
|||||||
speech_language: str = "en-US" #: :meta private:
|
speech_language: str = "en-US" #: :meta private:
|
||||||
speech_config: Any #: :meta private:
|
speech_config: Any #: :meta private:
|
||||||
|
|
||||||
name = "Azure Cognitive Services Text2Speech"
|
name = "azure_cognitive_services_text2speech"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Azure Cognitive Services Text2Speech. "
|
"A wrapper around Azure Cognitive Services Text2Speech. "
|
||||||
"Useful for when you need to convert text to speech. "
|
"Useful for when you need to convert text to speech. "
|
||||||
|
@ -13,7 +13,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper
|
|||||||
class BingSearchRun(BaseTool):
|
class BingSearchRun(BaseTool):
|
||||||
"""Tool that adds the capability to query the Bing search API."""
|
"""Tool that adds the capability to query the Bing search API."""
|
||||||
|
|
||||||
name = "Bing Search"
|
name = "bing_search"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Bing Search. "
|
"A wrapper around Bing Search. "
|
||||||
"Useful for when you need to answer questions about current events. "
|
"Useful for when you need to answer questions about current events. "
|
||||||
|
@ -11,7 +11,7 @@ from langchain.utilities.brave_search import BraveSearchWrapper
|
|||||||
|
|
||||||
|
|
||||||
class BraveSearch(BaseTool):
|
class BraveSearch(BaseTool):
|
||||||
name = "brave-search"
|
name = "brave_search"
|
||||||
description = (
|
description = (
|
||||||
"a search engine. "
|
"a search engine. "
|
||||||
"useful for when you need to answer questions about current events."
|
"useful for when you need to answer questions about current events."
|
||||||
|
@ -16,7 +16,7 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
|||||||
class DuckDuckGoSearchRun(BaseTool):
|
class DuckDuckGoSearchRun(BaseTool):
|
||||||
"""Tool that adds the capability to query the DuckDuckGo search API."""
|
"""Tool that adds the capability to query the DuckDuckGo search API."""
|
||||||
|
|
||||||
name = "DuckDuckGo Search"
|
name = "duckduckgo_search"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around DuckDuckGo Search. "
|
"A wrapper around DuckDuckGo Search. "
|
||||||
"Useful for when you need to answer questions about current events. "
|
"Useful for when you need to answer questions about current events. "
|
||||||
|
@ -19,7 +19,7 @@ class GooglePlacesSchema(BaseModel):
|
|||||||
class GooglePlacesTool(BaseTool):
|
class GooglePlacesTool(BaseTool):
|
||||||
"""Tool that adds the capability to query the Google places API."""
|
"""Tool that adds the capability to query the Google places API."""
|
||||||
|
|
||||||
name = "Google Places"
|
name = "google_places"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Google Places. "
|
"A wrapper around Google Places. "
|
||||||
"Useful for when you need to validate or "
|
"Useful for when you need to validate or "
|
||||||
|
@ -13,7 +13,7 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
|||||||
class GoogleSearchRun(BaseTool):
|
class GoogleSearchRun(BaseTool):
|
||||||
"""Tool that adds the capability to query the Google search API."""
|
"""Tool that adds the capability to query the Google search API."""
|
||||||
|
|
||||||
name = "Google Search"
|
name = "google_search"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Google Search. "
|
"A wrapper around Google Search. "
|
||||||
"Useful for when you need to answer questions about current events. "
|
"Useful for when you need to answer questions about current events. "
|
||||||
|
@ -15,7 +15,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
|||||||
class GoogleSerperRun(BaseTool):
|
class GoogleSerperRun(BaseTool):
|
||||||
"""Tool that adds the capability to query the Serper.dev Google search API."""
|
"""Tool that adds the capability to query the Serper.dev Google search API."""
|
||||||
|
|
||||||
name = "Google Serper"
|
name = "google_serper"
|
||||||
description = (
|
description = (
|
||||||
"A low-cost Google Search API."
|
"A low-cost Google Search API."
|
||||||
"Useful for when you need to answer questions about current events."
|
"Useful for when you need to answer questions about current events."
|
||||||
|
@ -19,7 +19,7 @@ def _print_func(text: str) -> None:
|
|||||||
class HumanInputRun(BaseTool):
|
class HumanInputRun(BaseTool):
|
||||||
"""Tool that adds the capability to ask user for input."""
|
"""Tool that adds the capability to ask user for input."""
|
||||||
|
|
||||||
name = "Human"
|
name = "human"
|
||||||
description = (
|
description = (
|
||||||
"You can ask a human for guidance when you think you "
|
"You can ask a human for guidance when you think you "
|
||||||
"got stuck or you are not sure what to do next. "
|
"got stuck or you are not sure what to do next. "
|
||||||
|
@ -13,7 +13,7 @@ from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
|
|||||||
class MetaphorSearchResults(BaseTool):
|
class MetaphorSearchResults(BaseTool):
|
||||||
"""Tool that has capability to query the Metaphor Search API and get back json."""
|
"""Tool that has capability to query the Metaphor Search API and get back json."""
|
||||||
|
|
||||||
name = "Metaphor Search Results JSON"
|
name = "metaphor_search_results_json"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Metaphor Search. "
|
"A wrapper around Metaphor Search. "
|
||||||
"Input should be a Metaphor-optimized query. "
|
"Input should be a Metaphor-optimized query. "
|
||||||
|
@ -34,7 +34,7 @@ def sanitize_input(query: str) -> str:
|
|||||||
class PythonREPLTool(BaseTool):
|
class PythonREPLTool(BaseTool):
|
||||||
"""A tool for running python code in a REPL."""
|
"""A tool for running python code in a REPL."""
|
||||||
|
|
||||||
name = "Python REPL"
|
name = "Python_REPL"
|
||||||
description = (
|
description = (
|
||||||
"A Python shell. Use this to execute python commands. "
|
"A Python shell. Use this to execute python commands. "
|
||||||
"Input should be a valid python command. "
|
"Input should be a valid python command. "
|
||||||
|
@ -20,7 +20,7 @@ class SceneXplainInput(BaseModel):
|
|||||||
class SceneXplainTool(BaseTool):
|
class SceneXplainTool(BaseTool):
|
||||||
"""Tool that adds the capability to explain images."""
|
"""Tool that adds the capability to explain images."""
|
||||||
|
|
||||||
name = "Image Explainer"
|
name = "image_explainer"
|
||||||
description = (
|
description = (
|
||||||
"An Image Captioning Tool: Use this tool to generate a detailed caption "
|
"An Image Captioning Tool: Use this tool to generate a detailed caption "
|
||||||
"for an image. The input can be an image file of any format, and "
|
"for an image. The input can be an image file of any format, and "
|
||||||
|
@ -14,7 +14,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
|
|||||||
class SearxSearchRun(BaseTool):
|
class SearxSearchRun(BaseTool):
|
||||||
"""Tool that adds the capability to query a Searx instance."""
|
"""Tool that adds the capability to query a Searx instance."""
|
||||||
|
|
||||||
name = "Searx Search"
|
name = "searx_search"
|
||||||
description = (
|
description = (
|
||||||
"A meta search engine."
|
"A meta search engine."
|
||||||
"Useful for when you need to answer questions about current events."
|
"Useful for when you need to answer questions about current events."
|
||||||
|
@ -13,7 +13,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
|||||||
class WolframAlphaQueryRun(BaseTool):
|
class WolframAlphaQueryRun(BaseTool):
|
||||||
"""Tool that adds the capability to query using the Wolfram Alpha SDK."""
|
"""Tool that adds the capability to query using the Wolfram Alpha SDK."""
|
||||||
|
|
||||||
name = "Wolfram Alpha"
|
name = "wolfram_alpha"
|
||||||
description = (
|
description = (
|
||||||
"A wrapper around Wolfram Alpha. "
|
"A wrapper around Wolfram Alpha. "
|
||||||
"Useful for when you need to answer questions about Math, "
|
"Useful for when you need to answer questions about Math, "
|
||||||
|
@ -19,7 +19,7 @@ from langchain.tools import BaseTool
|
|||||||
|
|
||||||
|
|
||||||
class YouTubeSearchTool(BaseTool):
|
class YouTubeSearchTool(BaseTool):
|
||||||
name = "YouTubeSearch"
|
name = "youtube_search"
|
||||||
description = (
|
description = (
|
||||||
"search for youtube videos associated with a person. "
|
"search for youtube videos associated with a person. "
|
||||||
"the input to this tool should be a comma separated list, "
|
"the input to this tool should be a comma separated list, "
|
||||||
|
@ -164,7 +164,7 @@ class ZapierNLAListActions(BaseTool):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "Zapier NLA: List Actions"
|
name = "ZapierNLA_list_actions"
|
||||||
description = BASE_ZAPIER_TOOL_PROMPT + (
|
description = BASE_ZAPIER_TOOL_PROMPT + (
|
||||||
"This tool returns a list of the user's exposed actions."
|
"This tool returns a list of the user's exposed actions."
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user