Harrison/use functions agent (#6185)

Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-06-15 08:18:50 -07:00 committed by GitHub
parent 7d2b946d0b
commit e82687ddf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 682 additions and 255 deletions

View File

@ -30,38 +30,83 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.llms import OpenAI" "from langchain.llms import OpenAI\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents.agent_types import AgentType"
]
},
{
"cell_type": "markdown",
"id": "bd806175",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a1717204",
"metadata": {},
"outputs": [],
"source": [
"agent = create_csv_agent(\n",
" OpenAI(temperature=0), \n",
" 'titanic.csv', \n",
" verbose=True, \n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c31bb8a6",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "16c4dc59",
"metadata": {},
"outputs": [],
"source": [
"agent = create_csv_agent(\n",
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
" 'titanic.csv', \n",
" verbose=True, \n",
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"id": "16c4dc59",
"metadata": {},
"outputs": [],
"source": [
"agent = create_csv_agent(OpenAI(temperature=0), 'titanic.csv', verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "46b9489d", "id": "46b9489d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df.shape[0]`\n",
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n",
"Action: python_repl_ast\n",
"Action Input: df.shape[0]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: There are 891 rows.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -69,10 +114,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'There are 891 rows.'" "'There are 891 rows in the dataframe.'"
] ]
}, },
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -83,23 +128,28 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "a96309be", "id": "a96309be",
"metadata": {}, "metadata": {
"scrolled": false
},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df[df['SibSp'] > 3]['PassengerId'].count()`\n",
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m30\u001b[0m\u001b[32;1m\u001b[1;3mThere are 30 people in the dataframe who have more than 3 siblings.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of people with more than 3 siblings\n",
"Action: python_repl_ast\n",
"Action Input: df[df['SibSp'] > 3].shape[0]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 30 people have more than 3 siblings.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -107,10 +157,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'30 people have more than 3 siblings.'" "'There are 30 people in the dataframe who have more than 3 siblings.'"
] ]
}, },
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -121,35 +171,39 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"id": "964a09f7", "id": "964a09f7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `import pandas as pd\n",
"import math\n",
"\n",
"# Create a dataframe\n",
"data = {'Age': [22, 38, 26, 35, 35]}\n",
"df = pd.DataFrame(data)\n",
"\n",
"# Calculate the average age\n",
"average_age = df['Age'].mean()\n",
"\n",
"# Calculate the square root of the average age\n",
"square_root = math.sqrt(average_age)\n",
"\n",
"square_root`\n",
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m5.585696017507576\u001b[0m\u001b[32;1m\u001b[1;3mThe square root of the average age is approximately 5.59.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to calculate the average age first\n",
"Action: python_repl_ast\n",
"Action Input: df['Age'].mean()\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
"Action: python_repl_ast\n",
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mNameError(\"name 'math' is not defined\")\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to import the math library\n",
"Action: python_repl_ast\n",
"Action Input: import math\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
"Action: python_repl_ast\n",
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 5.449689683556195\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -157,10 +211,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'5.449689683556195'" "'The square root of the average age is approximately 5.59.'"
] ]
}, },
"execution_count": 7, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -181,23 +235,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 8,
"id": "15f11fbd", "id": "15f11fbd",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df1['Age'].nunique() - df2['Age'].nunique()`\n",
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m-1\u001b[0m\u001b[32;1m\u001b[1;3mThere is 1 row in the age column that is different between the two dataframes.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to compare the age columns in both dataframes\n",
"Action: python_repl_ast\n",
"Action Input: len(df1[df1['Age'] != df2['Age']])\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m177\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 177 rows in the age column are different.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -205,17 +262,17 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'177 rows in the age column are different.'" "'There is 1 row in the age column that is different between the two dataframes.'"
] ]
}, },
"execution_count": 3, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"agent = create_csv_agent(OpenAI(temperature=0), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True)\n", "agent = create_csv_agent(ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS)\n",
"agent.run(\"how many rows in the age column are different?\")" "agent.run(\"how many rows in the age column are different between the two dfs?\")"
] ]
}, },
{ {

View File

@ -19,7 +19,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.agents import create_pandas_dataframe_agent" "from langchain.agents import create_pandas_dataframe_agent\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents.agent_types import AgentType"
] ]
}, },
{ {
@ -35,6 +37,16 @@
"df = pd.read_csv('titanic.csv')" "df = pd.read_csv('titanic.csv')"
] ]
}, },
{
"cell_type": "markdown",
"id": "a62858e2",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
@ -45,9 +57,34 @@
"agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)" "agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)"
] ]
}, },
{
"cell_type": "markdown",
"id": "7233ab56",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"id": "a8ea710e",
"metadata": {},
"outputs": [],
"source": [
"agent = create_pandas_dataframe_agent(\n",
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
" df, \n",
" verbose=True, \n",
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9207a2e", "id": "a9207a2e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -57,13 +94,12 @@
"text": [ "text": [
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n", "\u001b[32;1m\u001b[1;3m\n",
"Action: python_repl_ast\n", "Invoking: `python_repl_ast` with `df.shape[0]`\n",
"Action Input: df.shape[0]\u001b[0m\n", "\n",
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n", "\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
"Final Answer: There are 891 rows.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -71,10 +107,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'There are 891 rows.'" "'There are 891 rows in the dataframe.'"
] ]
}, },
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -172,7 +208,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "c4bc0584", "id": "c4bc0584",
"metadata": {}, "metadata": {},
@ -257,7 +292,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.16" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -12,7 +12,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 1,
"id": "f98e9c90-5c37-4fb9-af3e-d09693af8543", "id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -22,12 +22,24 @@
"from langchain.agents.agent_toolkits import create_python_agent\n", "from langchain.agents.agent_toolkits import create_python_agent\n",
"from langchain.tools.python.tool import PythonREPLTool\n", "from langchain.tools.python.tool import PythonREPLTool\n",
"from langchain.python import PythonREPL\n", "from langchain.python import PythonREPL\n",
"from langchain.llms.openai import OpenAI" "from langchain.llms.openai import OpenAI\n",
"from langchain.agents.agent_types import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"id": "ca30d64c",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 2,
"id": "cc422f53-c51c-4694-a834-72ecd1e68363", "id": "cc422f53-c51c-4694-a834-72ecd1e68363",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -37,7 +49,34 @@
"agent_executor = create_python_agent(\n", "agent_executor = create_python_agent(\n",
" llm=OpenAI(temperature=0, max_tokens=1000),\n", " llm=OpenAI(temperature=0, max_tokens=1000),\n",
" tool=PythonREPLTool(),\n", " tool=PythonREPLTool(),\n",
" verbose=True\n", " verbose=True,\n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")"
]
},
{
"cell_type": "markdown",
"id": "bb487e8e",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6e651822",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = create_python_agent(\n",
" llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
" tool=PythonREPLTool(),\n",
" verbose=True,\n",
" agent_type=AgentType.OPENAI_FUNCTIONS,\n",
" agent_executor_kwargs={\"handle_parsing_errors\": True},\n",
")" ")"
] ]
}, },
@ -52,7 +91,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe", "id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -64,23 +103,20 @@
"text": [ "text": [
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate the 10th fibonacci number\n", "\u001b[32;1m\u001b[1;3m\n",
"Action: Python REPL\n", "Invoking: `Python_REPL` with `def fibonacci(n):\n",
"Action Input: def fibonacci(n):\n", " if n <= 0:\n",
" if n == 0:\n",
" return 0\n", " return 0\n",
" elif n == 1:\n", " elif n == 1:\n",
" return 1\n", " return 1\n",
" else:\n", " else:\n",
" return fibonacci(n-1) + fibonacci(n-2)\u001b[0m\n", " return fibonacci(n-1) + fibonacci(n-2)\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", "\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to call the function with 10 as the argument\n", "fibonacci(10)`\n",
"Action: Python REPL\n", "\n",
"Action Input: fibonacci(10)\u001b[0m\n", "\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe 10th Fibonacci number is 55.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 55\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -88,10 +124,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'55'" "'The 10th Fibonacci number is 55.'"
] ]
}, },
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -111,9 +147,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"id": "4b9f60e7-eb6a-4f14-8604-498d863d4482", "id": "4b9f60e7-eb6a-4f14-8604-498d863d4482",
"metadata": {}, "metadata": {
"scrolled": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -121,59 +159,70 @@
"text": [ "text": [
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to write a neural network in PyTorch and train it on the given data.\n", "\u001b[32;1m\u001b[1;3mCould not parse tool input: {'name': 'python', 'arguments': 'import torch\\nimport torch.nn as nn\\nimport torch.optim as optim\\n\\n# Define the neural network\\nclass SingleNeuron(nn.Module):\\n def __init__(self):\\n super(SingleNeuron, self).__init__()\\n self.linear = nn.Linear(1, 1)\\n \\n def forward(self, x):\\n return self.linear(x)\\n\\n# Create the synthetic data\\nx_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\\ny_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\\n\\n# Create the neural network\\nmodel = SingleNeuron()\\n\\n# Define the loss function and optimizer\\ncriterion = nn.MSELoss()\\noptimizer = optim.SGD(model.parameters(), lr=0.01)\\n\\n# Train the neural network\\nfor epoch in range(1, 1001):\\n # Forward pass\\n y_pred = model(x_train)\\n \\n # Compute loss\\n loss = criterion(y_pred, y_train)\\n \\n # Backward pass and optimization\\n optimizer.zero_grad()\\n loss.backward()\\n optimizer.step()\\n \\n # Print the loss every 100 epochs\\n if epoch % 100 == 0:\\n print(f\"Epoch {epoch}: Loss = {loss.item()}\")\\n\\n# Make a prediction for x = 5\\nx_test = torch.tensor([[5.0]], dtype=torch.float32)\\ny_pred = model(x_test)\\ny_pred.item()'} because the `arguments` is not valid JSON.\u001b[0mInvalid or incomplete response\u001b[32;1m\u001b[1;3m\n",
"Action: Python REPL\n", "Invoking: `Python_REPL` with `import torch\n",
"Action Input: \n", "import torch.nn as nn\n",
"import torch\n", "import torch.optim as optim\n",
"\n", "\n",
"# Define the model\n", "# Define the neural network\n",
"model = torch.nn.Sequential(\n", "class SingleNeuron(nn.Module):\n",
" torch.nn.Linear(1, 1)\n", " def __init__(self):\n",
")\n", " super(SingleNeuron, self).__init__()\n",
" self.linear = nn.Linear(1, 1)\n",
" \n",
" def forward(self, x):\n",
" return self.linear(x)\n",
"\n", "\n",
"# Define the loss\n", "# Create the synthetic data\n",
"loss_fn = torch.nn.MSELoss()\n", "x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\n",
"y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\n",
"\n", "\n",
"# Define the optimizer\n", "# Create the neural network\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", "model = SingleNeuron()\n",
"\n", "\n",
"# Define the data\n", "# Define the loss function and optimizer\n",
"x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n", "criterion = nn.MSELoss()\n",
"y_data = torch.tensor([[2.0], [4.0], [6.0], [8.0]])\n", "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
"\n", "\n",
"# Train the model\n", "# Train the neural network\n",
"for epoch in range(1000):\n", "for epoch in range(1, 1001):\n",
" # Forward pass\n", " # Forward pass\n",
" y_pred = model(x_data)\n", " y_pred = model(x_train)\n",
"\n", " \n",
" # Compute and print loss\n", " # Compute loss\n",
" loss = loss_fn(y_pred, y_data)\n", " loss = criterion(y_pred, y_train)\n",
" if (epoch+1) % 100 == 0:\n", " \n",
" print(f'Epoch {epoch+1}: loss = {loss.item():.4f}')\n", " # Backward pass and optimization\n",
"\n",
" # Zero the gradients\n",
" optimizer.zero_grad()\n", " optimizer.zero_grad()\n",
"\n",
" # Backward pass\n",
" loss.backward()\n", " loss.backward()\n",
"\n",
" # Update the weights\n",
" optimizer.step()\n", " optimizer.step()\n",
"\u001b[0m\n", " \n",
"Observation: \u001b[36;1m\u001b[1;3mEpoch 100: loss = 0.0013\n", " # Print the loss every 100 epochs\n",
"Epoch 200: loss = 0.0007\n", " if epoch % 100 == 0:\n",
"Epoch 300: loss = 0.0004\n", " print(f\"Epoch {epoch}: Loss = {loss.item()}\")\n",
"Epoch 400: loss = 0.0002\n", "\n",
"Epoch 500: loss = 0.0001\n", "# Make a prediction for x = 5\n",
"Epoch 600: loss = 0.0001\n", "x_test = torch.tensor([[5.0]], dtype=torch.float32)\n",
"Epoch 700: loss = 0.0000\n", "y_pred = model(x_test)\n",
"Epoch 800: loss = 0.0000\n", "y_pred.item()`\n",
"Epoch 900: loss = 0.0000\n", "\n",
"Epoch 1000: loss = 0.0000\n", "\n",
"\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3mEpoch 100: Loss = 0.03825576975941658\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Epoch 200: Loss = 0.02100197970867157\n",
"Final Answer: The prediction for x = 5 is 10.0.\u001b[0m\n", "Epoch 300: Loss = 0.01152981910854578\n",
"Epoch 400: Loss = 0.006329738534986973\n",
"Epoch 500: Loss = 0.0034749575424939394\n",
"Epoch 600: Loss = 0.0019077073084190488\n",
"Epoch 700: Loss = 0.001047312980517745\n",
"Epoch 800: Loss = 0.0005749554838985205\n",
"Epoch 900: Loss = 0.0003156439634039998\n",
"Epoch 1000: Loss = 0.00017328384274151176\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Python_REPL` with `x_test.item()`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe prediction for x = 5 is 10.000173568725586.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -181,10 +230,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'The prediction for x = 5 is 10.0.'" "'The prediction for x = 5 is 10.000173568725586.'"
] ]
}, },
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -206,9 +255,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "LangChain", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "langchain" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -220,7 +269,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.16" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -37,7 +37,30 @@
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n", "from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
"from langchain.sql_database import SQLDatabase\n", "from langchain.sql_database import SQLDatabase\n",
"from langchain.llms.openai import OpenAI\n", "from langchain.llms.openai import OpenAI\n",
"from langchain.agents import AgentExecutor" "from langchain.agents import AgentExecutor\n",
"from langchain.agents.agent_types import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "65ec5bb3",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"toolkit = SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0))\n"
]
},
{
"cell_type": "markdown",
"id": "f74d1792",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
] ]
}, },
{ {
@ -49,16 +72,39 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
"toolkit = SQLDatabaseToolkit(db=db)\n",
"\n",
"agent_executor = create_sql_agent(\n", "agent_executor = create_sql_agent(\n",
" llm=OpenAI(temperature=0),\n", " llm=OpenAI(temperature=0),\n",
" toolkit=toolkit,\n", " toolkit=toolkit,\n",
" verbose=True\n", " verbose=True,\n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")" ")"
] ]
}, },
{
"cell_type": "markdown",
"id": "971cc455",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6426a27d",
"metadata": {},
"outputs": [],
"source": [
"# agent_executor = create_sql_agent(\n",
"# llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
"# toolkit=toolkit,\n",
"# verbose=True,\n",
"# agent_type=AgentType.OPENAI_FUNCTIONS\n",
"# )"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b", "id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b",
@ -69,7 +115,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7", "id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -81,14 +127,16 @@
"text": [ "text": [
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n", "\u001b[32;1m\u001b[1;3m\n",
"Action Input: \"\"\u001b[0m\n", "Invoking: `list_tables_sql_db` with `{}`\n",
"Observation: \u001b[38;5;200m\u001b[1;3mArtist, Invoice, Playlist, Genre, Album, PlaylistTrack, Track, InvoiceLine, MediaType, Employee, Customer\u001b[0m\n", "\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the playlisttrack table\n", "\n",
"Action: schema_sql_db\n", "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Track, PlaylistTrack, InvoiceLine, sales_table, Playlist, Genre, Employee, Customer, Invoice, MediaType\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Action Input: \"PlaylistTrack\"\u001b[0m\n", "Invoking: `schema_sql_db` with `PlaylistTrack`\n",
"Observation: \u001b[33;1m\u001b[1;3m\n", "\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n", "CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n", "\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n", "\t\"TrackId\" INTEGER NOT NULL, \n",
@ -97,13 +145,36 @@
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n", ")\n",
"\n", "\n",
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n", "/*\n",
"PlaylistId TrackId\n", "3 rows from PlaylistTrack table:\n",
"1 3402\n", "PlaylistId\tTrackId\n",
"1 3389\n", "1\t3402\n",
"1 3390\u001b[0m\n", "1\t3389\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "1\t3390\n",
"Final Answer: The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.\u001b[0m\n", "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \n",
"\n",
"Here is the schema of the `PlaylistTrack` table:\n",
"\n",
"```\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"```\n",
"\n",
"Here are three sample rows from the `PlaylistTrack` table:\n",
"\n",
"```\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\n",
"```\n",
"\n",
"Please let me know if there is anything else I can help you with.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -111,10 +182,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.'" "'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help you with.'"
] ]
}, },
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -519,7 +590,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.9" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,19 +1,26 @@
"""Agent for working with pandas objects.""" """Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.pandas.prompt import ( from langchain.agents.agent_toolkits.pandas.prompt import (
FUNCTIONS_WITH_DF,
FUNCTIONS_WITH_MULTI_DF,
MULTI_DF_PREFIX, MULTI_DF_PREFIX,
MULTI_DF_PREFIX_FUNCTIONS,
PREFIX, PREFIX,
PREFIX_FUNCTIONS,
SUFFIX_NO_DF, SUFFIX_NO_DF,
SUFFIX_WITH_DF, SUFFIX_WITH_DF,
SUFFIX_WITH_MULTI_DF, SUFFIX_WITH_MULTI_DF,
) )
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import SystemMessage
from langchain.tools.python.tool import PythonAstREPLTool from langchain.tools.python.tool import PythonAstREPLTool
@ -137,9 +144,108 @@ def _get_prompt_and_tools(
) )
def _get_functions_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the OpenAI-functions prompt and tool list for a single dataframe.

    Args:
        df: The pandas dataframe the agent will operate on.
        prefix: Optional override for the system-message prefix; defaults to
            ``PREFIX_FUNCTIONS`` when ``None``.
        suffix: Optional override for the system-message suffix. When given
            together with a truthy ``include_df_in_prompt``, it is formatted
            with the markdown rendering of ``df.head()``.
        include_df_in_prompt: Whether to embed ``df.head()`` (as markdown)
            in the prompt.

    Returns:
        A ``(prompt, tools)`` pair: the OpenAI-functions prompt built from a
        single system message, and one ``PythonAstREPLTool`` whose locals
        expose the dataframe as ``df``.
    """
    # Resolve the suffix: a caller-supplied suffix wins; otherwise fall back
    # to the stock FUNCTIONS_WITH_DF template (or nothing at all).
    if suffix is not None:
        tail = (
            suffix.format(df_head=str(df.head().to_markdown()))
            if include_df_in_prompt
            else suffix
        )
    else:
        tail = (
            FUNCTIONS_WITH_DF.format(df_head=str(df.head().to_markdown()))
            if include_df_in_prompt
            else ""
        )
    head = PREFIX_FUNCTIONS if prefix is None else prefix
    repl_tool = PythonAstREPLTool(locals={"df": df})
    system_message = SystemMessage(content=head + tail)
    return OpenAIFunctionsAgent.create_prompt(system_message=system_message), [repl_tool]
def _get_functions_multi_prompt(
dfs: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
suffix_to_use = suffix_to_use.format(
dfs_head=dfs_head,
)
elif include_df_in_prompt:
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
dfs_head=dfs_head,
)
else:
suffix_to_use = ""
if prefix is None:
prefix = MULTI_DF_PREFIX_FUNCTIONS
prefix = prefix.format(num_dfs=str(len(dfs)))
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
def _get_functions_prompt_and_tools(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
try:
import pandas as pd
except ImportError:
raise ValueError(
"pandas package not found, please install with `pip install pandas`"
)
if input_variables is not None:
raise ValueError("`input_variables` is not supported at the moment.")
if include_df_in_prompt is not None and suffix is not None:
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
if isinstance(df, list):
for item in df:
if not isinstance(item, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_multi_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_single_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
)
def create_pandas_dataframe_agent( def create_pandas_dataframe_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
df: Any, df: Any,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
suffix: Optional[str] = None, suffix: Optional[str] = None,
@ -154,26 +260,44 @@ def create_pandas_dataframe_agent(
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe.""" """Construct a pandas agent from an LLM and dataframe."""
agent: BaseSingleActionAgent
prompt, tools = _get_prompt_and_tools( if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
df, prompt, tools = _get_prompt_and_tools(
prefix=prefix, df,
suffix=suffix, prefix=prefix,
input_variables=input_variables, suffix=suffix,
include_df_in_prompt=include_df_in_prompt, input_variables=input_variables,
) include_df_in_prompt=include_df_in_prompt,
llm_chain = LLMChain( )
llm=llm, llm_chain = LLMChain(
prompt=prompt, llm=llm,
callback_manager=callback_manager, prompt=prompt,
) callback_manager=callback_manager,
tool_names = [tool.name for tool in tools] )
agent = ZeroShotAgent( tool_names = [tool.name for tool in tools]
llm_chain=llm_chain, agent = ZeroShotAgent(
allowed_tools=tool_names, llm_chain=llm_chain,
callback_manager=callback_manager, allowed_tools=tool_names,
**kwargs, callback_manager=callback_manager,
) **kwargs,
)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
_prompt, tools = _get_functions_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent, agent=agent,
tools=tools, tools=tools,

View File

@ -28,3 +28,17 @@ This is the result of `print(df.head())` for each dataframe:
Begin! Begin!
Question: {input} Question: {input}
{agent_scratchpad}""" {agent_scratchpad}"""
PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""
FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""

View File

@ -2,18 +2,22 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.python.prompt import PREFIX from langchain.agents.agent_toolkits.python.prompt import PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.schema import SystemMessage
from langchain.tools.python.tool import PythonREPLTool from langchain.tools.python.tool import PythonREPLTool
def create_python_agent( def create_python_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
tool: PythonREPLTool, tool: PythonREPLTool,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False, verbose: bool = False,
prefix: str = PREFIX, prefix: str = PREFIX,
@ -22,14 +26,29 @@ def create_python_agent(
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct a python agent from an LLM and tool.""" """Construct a python agent from an LLM and tool."""
tools = [tool] tools = [tool]
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) agent: BaseSingleActionAgent
llm_chain = LLMChain(
llm=llm, if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt=prompt, prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
callback_manager=callback_manager, llm_chain = LLMChain(
) llm=llm,
tool_names = [tool.name for tool in tools] prompt=prompt,
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
system_message = SystemMessage(content=prefix)
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent, agent=agent,
tools=tools, tools=tools,

View File

@ -1,22 +1,35 @@
"""SQL agent.""" """SQL agent."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema import AIMessage, SystemMessage
def create_sql_agent( def create_sql_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit, toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX, prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX, suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
top_k: int = 10, top_k: int = 10,
@ -30,20 +43,44 @@ def create_sql_agent(
"""Construct a sql agent from an LLM and tools.""" """Construct a sql agent from an LLM and tools."""
tools = toolkit.get_tools() tools = toolkit.get_tools()
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
prompt = ZeroShotAgent.create_prompt( agent: BaseSingleActionAgent
tools,
prefix=prefix, if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
suffix=suffix, prompt = ZeroShotAgent.create_prompt(
format_instructions=format_instructions, tools,
input_variables=input_variables, prefix=prefix,
) suffix=suffix or SQL_SUFFIX,
llm_chain = LLMChain( format_instructions=format_instructions,
llm=llm, input_variables=input_variables,
prompt=prompt, )
callback_manager=callback_manager, llm_chain = LLMChain(
) llm=llm,
tool_names = [tool.name for tool in tools] prompt=prompt,
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent, agent=agent,
tools=tools, tools=tools,

View File

@ -19,3 +19,5 @@ SQL_SUFFIX = """Begin!
Question: {input} Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}""" {agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""

View File

@ -7,12 +7,11 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseSingleActionAgent from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import ( from langchain.callbacks.manager import Callbacks
Callbacks,
)
from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ( from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
@ -23,6 +22,7 @@ from langchain.schema import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
FunctionMessage, FunctionMessage,
OutputParserException,
SystemMessage, SystemMessage,
) )
from langchain.tools import BaseTool from langchain.tools import BaseTool
@ -34,7 +34,9 @@ class _FunctionsAgentAction(AgentAction):
message_log: List[BaseMessage] message_log: List[BaseMessage]
def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]: def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agent action to a message. """Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action. This code is used to reconstruct the original AI message from the agent action.
@ -45,9 +47,12 @@ def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMes
Returns: Returns:
AIMessage that corresponds to the original tool invocation. AIMessage that corresponds to the original tool invocation.
""" """
if not isinstance(agent_action, _FunctionsAgentAction): if isinstance(agent_action, _FunctionsAgentAction):
raise ValueError("This agent type only works with _FunctionsAgentAction") return agent_action.message_log + [
return agent_action.message_log _create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message( def _create_function_message(
@ -61,7 +66,10 @@ def _create_function_message(
FunctionMessage that corresponds to the original tool invocation FunctionMessage that corresponds to the original tool invocation
""" """
if not isinstance(observation, str): if not isinstance(observation, str):
content = json.dumps(observation) try:
content = json.dumps(observation)
except Exception:
content = str(observation)
else: else:
content = observation content = observation
return FunctionMessage( return FunctionMessage(
@ -83,8 +91,7 @@ def _format_intermediate_steps(
for intermediate_step in intermediate_steps: for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action)) messages.extend(_convert_agent_action_to_messages(agent_action, observation))
messages.append(_create_function_message(agent_action, observation))
return messages return messages
@ -102,7 +109,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
try: try:
_tool_input = json.loads(function_call["arguments"]) _tool_input = json.loads(function_call["arguments"])
except JSONDecodeError: except JSONDecodeError:
raise ValueError( raise OutputParserException(
f"Could not parse tool input: {function_call} because " f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON." f"the `arguments` is not valid JSON."
) )
@ -202,12 +209,24 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return agent_decision return agent_decision
@classmethod @classmethod
def create_prompt(cls) -> BasePromptTemplate: def create_prompt(
messages = [ cls,
SystemMessage(content="You are a helpful AI assistant."), system_message: Optional[SystemMessage] = SystemMessage(
HumanMessagePromptTemplate.from_template("{input}"), content="You are a helpful AI assistant."
MessagesPlaceholder(variable_name="agent_scratchpad"), ),
] ) -> BasePromptTemplate:
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
messages = []
messages.extend(
[
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
input_variables = ["input", "agent_scratchpad"] input_variables = ["input", "agent_scratchpad"]
return ChatPromptTemplate(input_variables=input_variables, messages=messages) return ChatPromptTemplate(input_variables=input_variables, messages=messages)

View File

@ -15,7 +15,7 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivQueryRun(BaseTool): class ArxivQueryRun(BaseTool):
"""Tool that adds the capability to search using the Arxiv API.""" """Tool that adds the capability to search using the Arxiv API."""
name = "Arxiv" name = "arxiv"
description = ( description = (
"A wrapper around Arxiv.org " "A wrapper around Arxiv.org "
"Useful for when you need to answer questions about Physics, Mathematics, " "Useful for when you need to answer questions about Physics, Mathematics, "

View File

@ -27,7 +27,7 @@ class AzureCogsFormRecognizerTool(BaseTool):
azure_cogs_endpoint: str = "" #: :meta private: azure_cogs_endpoint: str = "" #: :meta private:
doc_analysis_client: Any #: :meta private: doc_analysis_client: Any #: :meta private:
name = "Azure Cognitive Services Form Recognizer" name = "azure_cognitive_services_form_recognizer"
description = ( description = (
"A wrapper around Azure Cognitive Services Form Recognizer. " "A wrapper around Azure Cognitive Services Form Recognizer. "
"Useful for when you need to " "Useful for when you need to "

View File

@ -28,7 +28,7 @@ class AzureCogsImageAnalysisTool(BaseTool):
vision_service: Any #: :meta private: vision_service: Any #: :meta private:
analysis_options: Any #: :meta private: analysis_options: Any #: :meta private:
name = "Azure Cognitive Services Image Analysis" name = "azure_cognitive_services_image_analysis"
description = ( description = (
"A wrapper around Azure Cognitive Services Image Analysis. " "A wrapper around Azure Cognitive Services Image Analysis. "
"Useful for when you need to analyze images. " "Useful for when you need to analyze images. "

View File

@ -32,7 +32,7 @@ class AzureCogsSpeech2TextTool(BaseTool):
speech_language: str = "en-US" #: :meta private: speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private: speech_config: Any #: :meta private:
name = "Azure Cognitive Services Speech2Text" name = "azure_cognitive_services_speech2text"
description = ( description = (
"A wrapper around Azure Cognitive Services Speech2Text. " "A wrapper around Azure Cognitive Services Speech2Text. "
"Useful for when you need to transcribe audio to text. " "Useful for when you need to transcribe audio to text. "

View File

@ -28,7 +28,7 @@ class AzureCogsText2SpeechTool(BaseTool):
speech_language: str = "en-US" #: :meta private: speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private: speech_config: Any #: :meta private:
name = "Azure Cognitive Services Text2Speech" name = "azure_cognitive_services_text2speech"
description = ( description = (
"A wrapper around Azure Cognitive Services Text2Speech. " "A wrapper around Azure Cognitive Services Text2Speech. "
"Useful for when you need to convert text to speech. " "Useful for when you need to convert text to speech. "

View File

@ -13,7 +13,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper
class BingSearchRun(BaseTool): class BingSearchRun(BaseTool):
"""Tool that adds the capability to query the Bing search API.""" """Tool that adds the capability to query the Bing search API."""
name = "Bing Search" name = "bing_search"
description = ( description = (
"A wrapper around Bing Search. " "A wrapper around Bing Search. "
"Useful for when you need to answer questions about current events. " "Useful for when you need to answer questions about current events. "

View File

@ -11,7 +11,7 @@ from langchain.utilities.brave_search import BraveSearchWrapper
class BraveSearch(BaseTool): class BraveSearch(BaseTool):
name = "brave-search" name = "brave_search"
description = ( description = (
"a search engine. " "a search engine. "
"useful for when you need to answer questions about current events." "useful for when you need to answer questions about current events."

View File

@ -16,7 +16,7 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
class DuckDuckGoSearchRun(BaseTool): class DuckDuckGoSearchRun(BaseTool):
"""Tool that adds the capability to query the DuckDuckGo search API.""" """Tool that adds the capability to query the DuckDuckGo search API."""
name = "DuckDuckGo Search" name = "duckduckgo_search"
description = ( description = (
"A wrapper around DuckDuckGo Search. " "A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. " "Useful for when you need to answer questions about current events. "

View File

@ -19,7 +19,7 @@ class GooglePlacesSchema(BaseModel):
class GooglePlacesTool(BaseTool): class GooglePlacesTool(BaseTool):
"""Tool that adds the capability to query the Google places API.""" """Tool that adds the capability to query the Google places API."""
name = "Google Places" name = "google_places"
description = ( description = (
"A wrapper around Google Places. " "A wrapper around Google Places. "
"Useful for when you need to validate or " "Useful for when you need to validate or "

View File

@ -13,7 +13,7 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper
class GoogleSearchRun(BaseTool): class GoogleSearchRun(BaseTool):
"""Tool that adds the capability to query the Google search API.""" """Tool that adds the capability to query the Google search API."""
name = "Google Search" name = "google_search"
description = ( description = (
"A wrapper around Google Search. " "A wrapper around Google Search. "
"Useful for when you need to answer questions about current events. " "Useful for when you need to answer questions about current events. "

View File

@ -15,7 +15,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
class GoogleSerperRun(BaseTool): class GoogleSerperRun(BaseTool):
"""Tool that adds the capability to query the Serper.dev Google search API.""" """Tool that adds the capability to query the Serper.dev Google search API."""
name = "Google Serper" name = "google_serper"
description = ( description = (
"A low-cost Google Search API." "A low-cost Google Search API."
"Useful for when you need to answer questions about current events." "Useful for when you need to answer questions about current events."

View File

@ -19,7 +19,7 @@ def _print_func(text: str) -> None:
class HumanInputRun(BaseTool): class HumanInputRun(BaseTool):
"""Tool that adds the capability to ask user for input.""" """Tool that adds the capability to ask user for input."""
name = "Human" name = "human"
description = ( description = (
"You can ask a human for guidance when you think you " "You can ask a human for guidance when you think you "
"got stuck or you are not sure what to do next. " "got stuck or you are not sure what to do next. "

View File

@ -13,7 +13,7 @@ from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
class MetaphorSearchResults(BaseTool): class MetaphorSearchResults(BaseTool):
"""Tool that has capability to query the Metaphor Search API and get back json.""" """Tool that has capability to query the Metaphor Search API and get back json."""
name = "Metaphor Search Results JSON" name = "metaphor_search_results_json"
description = ( description = (
"A wrapper around Metaphor Search. " "A wrapper around Metaphor Search. "
"Input should be a Metaphor-optimized query. " "Input should be a Metaphor-optimized query. "

View File

@ -34,7 +34,7 @@ def sanitize_input(query: str) -> str:
class PythonREPLTool(BaseTool): class PythonREPLTool(BaseTool):
"""A tool for running python code in a REPL.""" """A tool for running python code in a REPL."""
name = "Python REPL" name = "Python_REPL"
description = ( description = (
"A Python shell. Use this to execute python commands. " "A Python shell. Use this to execute python commands. "
"Input should be a valid python command. " "Input should be a valid python command. "

View File

@ -20,7 +20,7 @@ class SceneXplainInput(BaseModel):
class SceneXplainTool(BaseTool): class SceneXplainTool(BaseTool):
"""Tool that adds the capability to explain images.""" """Tool that adds the capability to explain images."""
name = "Image Explainer" name = "image_explainer"
description = ( description = (
"An Image Captioning Tool: Use this tool to generate a detailed caption " "An Image Captioning Tool: Use this tool to generate a detailed caption "
"for an image. The input can be an image file of any format, and " "for an image. The input can be an image file of any format, and "

View File

@ -14,7 +14,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
class SearxSearchRun(BaseTool): class SearxSearchRun(BaseTool):
"""Tool that adds the capability to query a Searx instance.""" """Tool that adds the capability to query a Searx instance."""
name = "Searx Search" name = "searx_search"
description = ( description = (
"A meta search engine." "A meta search engine."
"Useful for when you need to answer questions about current events." "Useful for when you need to answer questions about current events."

View File

@ -13,7 +13,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
class WolframAlphaQueryRun(BaseTool): class WolframAlphaQueryRun(BaseTool):
"""Tool that adds the capability to query using the Wolfram Alpha SDK.""" """Tool that adds the capability to query using the Wolfram Alpha SDK."""
name = "Wolfram Alpha" name = "wolfram_alpha"
description = ( description = (
"A wrapper around Wolfram Alpha. " "A wrapper around Wolfram Alpha. "
"Useful for when you need to answer questions about Math, " "Useful for when you need to answer questions about Math, "

View File

@ -19,7 +19,7 @@ from langchain.tools import BaseTool
class YouTubeSearchTool(BaseTool): class YouTubeSearchTool(BaseTool):
name = "YouTubeSearch" name = "youtube_search"
description = ( description = (
"search for youtube videos associated with a person. " "search for youtube videos associated with a person. "
"the input to this tool should be a comma separated list, " "the input to this tool should be a comma separated list, "

View File

@ -164,7 +164,7 @@ class ZapierNLAListActions(BaseTool):
""" """
name = "Zapier NLA: List Actions" name = "ZapierNLA_list_actions"
description = BASE_ZAPIER_TOOL_PROMPT + ( description = BASE_ZAPIER_TOOL_PROMPT + (
"This tool returns a list of the user's exposed actions." "This tool returns a list of the user's exposed actions."
) )