Harrison/use functions agent (#6185)

Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
searx_updates
Harrison Chase 11 months ago committed by GitHub
parent 7d2b946d0b
commit e82687ddf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,38 +30,83 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI"
"from langchain.llms import OpenAI\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents.agent_types import AgentType"
]
},
{
"cell_type": "markdown",
"id": "bd806175",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"id": "a1717204",
"metadata": {},
"outputs": [],
"source": [
"agent = create_csv_agent(\n",
" OpenAI(temperature=0), \n",
" 'titanic.csv', \n",
" verbose=True, \n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c31bb8a6",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "16c4dc59",
"metadata": {},
"outputs": [],
"source": [
"agent = create_csv_agent(OpenAI(temperature=0), 'titanic.csv', verbose=True)"
"agent = create_csv_agent(\n",
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
" 'titanic.csv', \n",
" verbose=True, \n",
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "46b9489d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df.shape[0]`\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n",
"Action: python_repl_ast\n",
"Action Input: df.shape[0]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: There are 891 rows.\u001b[0m\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -69,10 +114,10 @@
{
"data": {
"text/plain": [
"'There are 891 rows.'"
"'There are 891 rows in the dataframe.'"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -83,23 +128,28 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "a96309be",
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df[df['SibSp'] > 3]['PassengerId'].count()`\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of people with more than 3 siblings\n",
"Action: python_repl_ast\n",
"Action Input: df[df['SibSp'] > 3].shape[0]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 30 people have more than 3 siblings.\u001b[0m\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m30\u001b[0m\u001b[32;1m\u001b[1;3mThere are 30 people in the dataframe who have more than 3 siblings.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -107,10 +157,10 @@
{
"data": {
"text/plain": [
"'30 people have more than 3 siblings.'"
"'There are 30 people in the dataframe who have more than 3 siblings.'"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -121,35 +171,39 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "964a09f7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `import pandas as pd\n",
"import math\n",
"\n",
"# Create a dataframe\n",
"data = {'Age': [22, 38, 26, 35, 35]}\n",
"df = pd.DataFrame(data)\n",
"\n",
"# Calculate the average age\n",
"average_age = df['Age'].mean()\n",
"\n",
"# Calculate the square root of the average age\n",
"square_root = math.sqrt(average_age)\n",
"\n",
"square_root`\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to calculate the average age first\n",
"Action: python_repl_ast\n",
"Action Input: df['Age'].mean()\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
"Action: python_repl_ast\n",
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mNameError(\"name 'math' is not defined\")\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to import the math library\n",
"Action: python_repl_ast\n",
"Action Input: import math\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n",
"Action: python_repl_ast\n",
"Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 5.449689683556195\u001b[0m\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m5.585696017507576\u001b[0m\u001b[32;1m\u001b[1;3mThe square root of the average age is approximately 5.59.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -157,10 +211,10 @@
{
"data": {
"text/plain": [
"'5.449689683556195'"
"'The square root of the average age is approximately 5.59.'"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -181,23 +235,26 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"id": "15f11fbd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df1['Age'].nunique() - df2['Age'].nunique()`\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to compare the age columns in both dataframes\n",
"Action: python_repl_ast\n",
"Action Input: len(df1[df1['Age'] != df2['Age']])\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m177\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 177 rows in the age column are different.\u001b[0m\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m-1\u001b[0m\u001b[32;1m\u001b[1;3mThere is 1 row in the age column that is different between the two dataframes.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -205,17 +262,17 @@
{
"data": {
"text/plain": [
"'177 rows in the age column are different.'"
"'There is 1 row in the age column that is different between the two dataframes.'"
]
},
"execution_count": 3,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent = create_csv_agent(OpenAI(temperature=0), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True)\n",
"agent.run(\"how many rows in the age column are different?\")"
"agent = create_csv_agent(ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS)\n",
"agent.run(\"how many rows in the age column are different between the two dfs?\")"
]
},
{

@ -19,7 +19,9 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import create_pandas_dataframe_agent"
"from langchain.agents import create_pandas_dataframe_agent\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents.agent_types import AgentType"
]
},
{
@ -35,6 +37,16 @@
"df = pd.read_csv('titanic.csv')"
]
},
{
"cell_type": "markdown",
"id": "a62858e2",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
@ -45,9 +57,34 @@
"agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "7233ab56",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a8ea710e",
"metadata": {},
"outputs": [],
"source": [
"agent = create_pandas_dataframe_agent(\n",
" ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n",
" df, \n",
" verbose=True, \n",
" agent_type=AgentType.OPENAI_FUNCTIONS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9207a2e",
"metadata": {},
"outputs": [
@ -57,13 +94,12 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n",
"Action: python_repl_ast\n",
"Action Input: df.shape[0]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: There are 891 rows.\u001b[0m\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `df.shape[0]`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -71,10 +107,10 @@
{
"data": {
"text/plain": [
"'There are 891 rows.'"
"'There are 891 rows in the dataframe.'"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -172,7 +208,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c4bc0584",
"metadata": {},
@ -257,7 +292,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"id": "f98e9c90-5c37-4fb9-af3e-d09693af8543",
"metadata": {
"tags": []
@ -22,12 +22,24 @@
"from langchain.agents.agent_toolkits import create_python_agent\n",
"from langchain.tools.python.tool import PythonREPLTool\n",
"from langchain.python import PythonREPL\n",
"from langchain.llms.openai import OpenAI"
"from langchain.llms.openai import OpenAI\n",
"from langchain.agents.agent_types import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"id": "ca30d64c",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"id": "cc422f53-c51c-4694-a834-72ecd1e68363",
"metadata": {
"tags": []
@ -37,7 +49,34 @@
"agent_executor = create_python_agent(\n",
" llm=OpenAI(temperature=0, max_tokens=1000),\n",
" tool=PythonREPLTool(),\n",
" verbose=True\n",
" verbose=True,\n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")"
]
},
{
"cell_type": "markdown",
"id": "bb487e8e",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6e651822",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = create_python_agent(\n",
" llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
" tool=PythonREPLTool(),\n",
" verbose=True,\n",
" agent_type=AgentType.OPENAI_FUNCTIONS,\n",
" agent_executor_kwargs={\"handle_parsing_errors\": True},\n",
")"
]
},
@ -52,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe",
"metadata": {
"tags": []
@ -64,23 +103,20 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate the 10th fibonacci number\n",
"Action: Python REPL\n",
"Action Input: def fibonacci(n):\n",
" if n == 0:\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Python_REPL` with `def fibonacci(n):\n",
" if n <= 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fibonacci(n-1) + fibonacci(n-2)\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to call the function with 10 as the argument\n",
"Action: Python REPL\n",
"Action Input: fibonacci(10)\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 55\u001b[0m\n",
" return fibonacci(n-1) + fibonacci(n-2)\n",
"\n",
"fibonacci(10)`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe 10th Fibonacci number is 55.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -88,10 +124,10 @@
{
"data": {
"text/plain": [
"'55'"
"'The 10th Fibonacci number is 55.'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -111,9 +147,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "4b9f60e7-eb6a-4f14-8604-498d863d4482",
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
@ -121,59 +159,70 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to write a neural network in PyTorch and train it on the given data.\n",
"Action: Python REPL\n",
"Action Input: \n",
"import torch\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mCould not parse tool input: {'name': 'python', 'arguments': 'import torch\\nimport torch.nn as nn\\nimport torch.optim as optim\\n\\n# Define the neural network\\nclass SingleNeuron(nn.Module):\\n def __init__(self):\\n super(SingleNeuron, self).__init__()\\n self.linear = nn.Linear(1, 1)\\n \\n def forward(self, x):\\n return self.linear(x)\\n\\n# Create the synthetic data\\nx_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\\ny_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\\n\\n# Create the neural network\\nmodel = SingleNeuron()\\n\\n# Define the loss function and optimizer\\ncriterion = nn.MSELoss()\\noptimizer = optim.SGD(model.parameters(), lr=0.01)\\n\\n# Train the neural network\\nfor epoch in range(1, 1001):\\n # Forward pass\\n y_pred = model(x_train)\\n \\n # Compute loss\\n loss = criterion(y_pred, y_train)\\n \\n # Backward pass and optimization\\n optimizer.zero_grad()\\n loss.backward()\\n optimizer.step()\\n \\n # Print the loss every 100 epochs\\n if epoch % 100 == 0:\\n print(f\"Epoch {epoch}: Loss = {loss.item()}\")\\n\\n# Make a prediction for x = 5\\nx_test = torch.tensor([[5.0]], dtype=torch.float32)\\ny_pred = model(x_test)\\ny_pred.item()'} because the `arguments` is not valid JSON.\u001b[0mInvalid or incomplete response\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Python_REPL` with `import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"# Define the model\n",
"model = torch.nn.Sequential(\n",
" torch.nn.Linear(1, 1)\n",
")\n",
"# Define the neural network\n",
"class SingleNeuron(nn.Module):\n",
" def __init__(self):\n",
" super(SingleNeuron, self).__init__()\n",
" self.linear = nn.Linear(1, 1)\n",
" \n",
" def forward(self, x):\n",
" return self.linear(x)\n",
"\n",
"# Define the loss\n",
"loss_fn = torch.nn.MSELoss()\n",
"# Create the synthetic data\n",
"x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\n",
"y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\n",
"\n",
"# Define the optimizer\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
"# Create the neural network\n",
"model = SingleNeuron()\n",
"\n",
"# Define the data\n",
"x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n",
"y_data = torch.tensor([[2.0], [4.0], [6.0], [8.0]])\n",
"# Define the loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
"\n",
"# Train the model\n",
"for epoch in range(1000):\n",
"# Train the neural network\n",
"for epoch in range(1, 1001):\n",
" # Forward pass\n",
" y_pred = model(x_data)\n",
" y_pred = model(x_train)\n",
" \n",
" # Compute loss\n",
" loss = criterion(y_pred, y_train)\n",
" \n",
" # Backward pass and optimization\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" # Print the loss every 100 epochs\n",
" if epoch % 100 == 0:\n",
" print(f\"Epoch {epoch}: Loss = {loss.item()}\")\n",
"\n",
" # Compute and print loss\n",
" loss = loss_fn(y_pred, y_data)\n",
" if (epoch+1) % 100 == 0:\n",
" print(f'Epoch {epoch+1}: loss = {loss.item():.4f}')\n",
"# Make a prediction for x = 5\n",
"x_test = torch.tensor([[5.0]], dtype=torch.float32)\n",
"y_pred = model(x_test)\n",
"y_pred.item()`\n",
"\n",
" # Zero the gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # Backward pass\n",
" loss.backward()\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mEpoch 100: Loss = 0.03825576975941658\n",
"Epoch 200: Loss = 0.02100197970867157\n",
"Epoch 300: Loss = 0.01152981910854578\n",
"Epoch 400: Loss = 0.006329738534986973\n",
"Epoch 500: Loss = 0.0034749575424939394\n",
"Epoch 600: Loss = 0.0019077073084190488\n",
"Epoch 700: Loss = 0.001047312980517745\n",
"Epoch 800: Loss = 0.0005749554838985205\n",
"Epoch 900: Loss = 0.0003156439634039998\n",
"Epoch 1000: Loss = 0.00017328384274151176\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Python_REPL` with `x_test.item()`\n",
"\n",
" # Update the weights\n",
" optimizer.step()\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mEpoch 100: loss = 0.0013\n",
"Epoch 200: loss = 0.0007\n",
"Epoch 300: loss = 0.0004\n",
"Epoch 400: loss = 0.0002\n",
"Epoch 500: loss = 0.0001\n",
"Epoch 600: loss = 0.0001\n",
"Epoch 700: loss = 0.0000\n",
"Epoch 800: loss = 0.0000\n",
"Epoch 900: loss = 0.0000\n",
"Epoch 1000: loss = 0.0000\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The prediction for x = 5 is 10.0.\u001b[0m\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe prediction for x = 5 is 10.000173568725586.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -181,10 +230,10 @@
{
"data": {
"text/plain": [
"'The prediction for x = 5 is 10.0.'"
"'The prediction for x = 5 is 10.000173568725586.'"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -206,9 +255,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "LangChain",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "langchain"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -220,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -37,7 +37,30 @@
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
"from langchain.sql_database import SQLDatabase\n",
"from langchain.llms.openai import OpenAI\n",
"from langchain.agents import AgentExecutor"
"from langchain.agents import AgentExecutor\n",
"from langchain.agents.agent_types import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "65ec5bb3",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"toolkit = SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0))\n"
]
},
{
"cell_type": "markdown",
"id": "f74d1792",
"metadata": {},
"source": [
"## Using ZERO_SHOT_REACT_DESCRIPTION\n",
"\n",
"This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above."
]
},
{
@ -49,16 +72,39 @@
},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
"toolkit = SQLDatabaseToolkit(db=db)\n",
"\n",
"agent_executor = create_sql_agent(\n",
" llm=OpenAI(temperature=0),\n",
" toolkit=toolkit,\n",
" verbose=True\n",
" verbose=True,\n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")"
]
},
{
"cell_type": "markdown",
"id": "971cc455",
"metadata": {},
"source": [
"## Using OpenAI Functions\n",
"\n",
"This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6426a27d",
"metadata": {},
"outputs": [],
"source": [
"# agent_executor = create_sql_agent(\n",
"# llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n",
"# toolkit=toolkit,\n",
"# verbose=True,\n",
"# agent_type=AgentType.OPENAI_FUNCTIONS\n",
"# )"
]
},
{
"cell_type": "markdown",
"id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b",
@ -69,7 +115,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7",
"metadata": {
"tags": []
@ -81,14 +127,16 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mArtist, Invoice, Playlist, Genre, Album, PlaylistTrack, Track, InvoiceLine, MediaType, Employee, Customer\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the playlisttrack table\n",
"Action: schema_sql_db\n",
"Action Input: \"PlaylistTrack\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `list_tables_sql_db` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Track, PlaylistTrack, InvoiceLine, sales_table, Playlist, Genre, Employee, Customer, Invoice, MediaType\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `schema_sql_db` with `PlaylistTrack`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
@ -97,13 +145,36 @@
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.\u001b[0m\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \n",
"\n",
"Here is the schema of the `PlaylistTrack` table:\n",
"\n",
"```\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"```\n",
"\n",
"Here are three sample rows from the `PlaylistTrack` table:\n",
"\n",
"```\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\n",
"```\n",
"\n",
"Please let me know if there is anything else I can help you with.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -111,10 +182,10 @@
{
"data": {
"text/plain": [
"'The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.'"
"'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help you with.'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -519,7 +590,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -1,19 +1,26 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.pandas.prompt import (
FUNCTIONS_WITH_DF,
FUNCTIONS_WITH_MULTI_DF,
MULTI_DF_PREFIX,
MULTI_DF_PREFIX_FUNCTIONS,
PREFIX,
PREFIX_FUNCTIONS,
SUFFIX_NO_DF,
SUFFIX_WITH_DF,
SUFFIX_WITH_MULTI_DF,
)
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import SystemMessage
from langchain.tools.python.tool import PythonAstREPLTool
@ -137,9 +144,108 @@ def _get_prompt_and_tools(
)
def _get_functions_single_prompt(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
suffix_to_use = suffix_to_use.format(df_head=str(df.head().to_markdown()))
elif include_df_in_prompt:
suffix_to_use = FUNCTIONS_WITH_DF.format(df_head=str(df.head().to_markdown()))
else:
suffix_to_use = ""
if prefix is None:
prefix = PREFIX_FUNCTIONS
tools = [PythonAstREPLTool(locals={"df": df})]
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
def _get_functions_multi_prompt(
dfs: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
suffix_to_use = suffix_to_use.format(
dfs_head=dfs_head,
)
elif include_df_in_prompt:
dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs])
suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
dfs_head=dfs_head,
)
else:
suffix_to_use = ""
if prefix is None:
prefix = MULTI_DF_PREFIX_FUNCTIONS
prefix = prefix.format(num_dfs=str(len(dfs)))
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
def _get_functions_prompt_and_tools(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
try:
import pandas as pd
except ImportError:
raise ValueError(
"pandas package not found, please install with `pip install pandas`"
)
if input_variables is not None:
raise ValueError("`input_variables` is not supported at the moment.")
if include_df_in_prompt is not None and suffix is not None:
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
if isinstance(df, list):
for item in df:
if not isinstance(item, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_multi_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_single_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
)
def create_pandas_dataframe_agent(
llm: BaseLanguageModel,
df: Any,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
@ -154,26 +260,44 @@ def create_pandas_dataframe_agent(
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
prompt, tools = _get_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, tools = _get_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
_prompt, tools = _get_functions_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,

@ -28,3 +28,17 @@ This is the result of `print(df.head())` for each dataframe:
Begin!
Question: {input}
{agent_scratchpad}"""
PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""
FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""

@ -2,18 +2,22 @@
from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.python.prompt import PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema import SystemMessage
from langchain.tools.python.tool import PythonREPLTool
def create_python_agent(
llm: BaseLanguageModel,
tool: PythonREPLTool,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,
prefix: str = PREFIX,
@ -22,14 +26,29 @@ def create_python_agent(
) -> AgentExecutor:
"""Construct a python agent from an LLM and tool."""
tools = [tool]
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
system_message = SystemMessage(content=prefix)
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,

@ -1,22 +1,35 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema import AIMessage, SystemMessage
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
@ -30,20 +43,44 @@ def create_sql_agent(
"""Construct a sql agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,

@ -19,3 +19,5 @@ SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""

@ -7,12 +7,11 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
Callbacks,
)
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
@ -23,6 +22,7 @@ from langchain.schema import (
AIMessage,
BaseMessage,
FunctionMessage,
OutputParserException,
SystemMessage,
)
from langchain.tools import BaseTool
@ -34,7 +34,9 @@ class _FunctionsAgentAction(AgentAction):
message_log: List[BaseMessage]
def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]:
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
@ -45,9 +47,12 @@ def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMes
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if not isinstance(agent_action, _FunctionsAgentAction):
raise ValueError("This agent type only works with _FunctionsAgentAction")
return agent_action.message_log
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
@ -61,7 +66,10 @@ def _create_function_message(
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
content = json.dumps(observation)
try:
content = json.dumps(observation)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
@ -83,8 +91,7 @@ def _format_intermediate_steps(
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action))
messages.append(_create_function_message(agent_action, observation))
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
@ -102,7 +109,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
raise ValueError(
raise OutputParserException(
f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON."
)
@ -202,12 +209,24 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return agent_decision
@classmethod
def create_prompt(cls) -> BasePromptTemplate:
messages = [
SystemMessage(content="You are a helpful AI assistant."),
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
def create_prompt(
cls,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
) -> BasePromptTemplate:
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
messages = []
messages.extend(
[
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
input_variables = ["input", "agent_scratchpad"]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)

@ -15,7 +15,7 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivQueryRun(BaseTool):
"""Tool that adds the capability to search using the Arxiv API."""
name = "Arxiv"
name = "arxiv"
description = (
"A wrapper around Arxiv.org "
"Useful for when you need to answer questions about Physics, Mathematics, "

@ -27,7 +27,7 @@ class AzureCogsFormRecognizerTool(BaseTool):
azure_cogs_endpoint: str = "" #: :meta private:
doc_analysis_client: Any #: :meta private:
name = "Azure Cognitive Services Form Recognizer"
name = "azure_cognitive_services_form_recognizer"
description = (
"A wrapper around Azure Cognitive Services Form Recognizer. "
"Useful for when you need to "

@ -28,7 +28,7 @@ class AzureCogsImageAnalysisTool(BaseTool):
vision_service: Any #: :meta private:
analysis_options: Any #: :meta private:
name = "Azure Cognitive Services Image Analysis"
name = "azure_cognitive_services_image_analysis"
description = (
"A wrapper around Azure Cognitive Services Image Analysis. "
"Useful for when you need to analyze images. "

@ -32,7 +32,7 @@ class AzureCogsSpeech2TextTool(BaseTool):
speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private:
name = "Azure Cognitive Services Speech2Text"
name = "azure_cognitive_services_speech2text"
description = (
"A wrapper around Azure Cognitive Services Speech2Text. "
"Useful for when you need to transcribe audio to text. "

@ -28,7 +28,7 @@ class AzureCogsText2SpeechTool(BaseTool):
speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private:
name = "Azure Cognitive Services Text2Speech"
name = "azure_cognitive_services_text2speech"
description = (
"A wrapper around Azure Cognitive Services Text2Speech. "
"Useful for when you need to convert text to speech. "

@ -13,7 +13,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper
class BingSearchRun(BaseTool):
"""Tool that adds the capability to query the Bing search API."""
name = "Bing Search"
name = "bing_search"
description = (
"A wrapper around Bing Search. "
"Useful for when you need to answer questions about current events. "

@ -11,7 +11,7 @@ from langchain.utilities.brave_search import BraveSearchWrapper
class BraveSearch(BaseTool):
name = "brave-search"
name = "brave_search"
description = (
"a search engine. "
"useful for when you need to answer questions about current events."

@ -16,7 +16,7 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
class DuckDuckGoSearchRun(BaseTool):
"""Tool that adds the capability to query the DuckDuckGo search API."""
name = "DuckDuckGo Search"
name = "duckduckgo_search"
description = (
"A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. "

@ -19,7 +19,7 @@ class GooglePlacesSchema(BaseModel):
class GooglePlacesTool(BaseTool):
"""Tool that adds the capability to query the Google places API."""
name = "Google Places"
name = "google_places"
description = (
"A wrapper around Google Places. "
"Useful for when you need to validate or "

@ -13,7 +13,7 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper
class GoogleSearchRun(BaseTool):
"""Tool that adds the capability to query the Google search API."""
name = "Google Search"
name = "google_search"
description = (
"A wrapper around Google Search. "
"Useful for when you need to answer questions about current events. "

@ -15,7 +15,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
class GoogleSerperRun(BaseTool):
"""Tool that adds the capability to query the Serper.dev Google search API."""
name = "Google Serper"
name = "google_serper"
description = (
"A low-cost Google Search API."
"Useful for when you need to answer questions about current events."

@ -19,7 +19,7 @@ def _print_func(text: str) -> None:
class HumanInputRun(BaseTool):
"""Tool that adds the capability to ask user for input."""
name = "Human"
name = "human"
description = (
"You can ask a human for guidance when you think you "
"got stuck or you are not sure what to do next. "

@ -13,7 +13,7 @@ from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
class MetaphorSearchResults(BaseTool):
"""Tool that has capability to query the Metaphor Search API and get back json."""
name = "Metaphor Search Results JSON"
name = "metaphor_search_results_json"
description = (
"A wrapper around Metaphor Search. "
"Input should be a Metaphor-optimized query. "

@ -34,7 +34,7 @@ def sanitize_input(query: str) -> str:
class PythonREPLTool(BaseTool):
"""A tool for running python code in a REPL."""
name = "Python REPL"
name = "Python_REPL"
description = (
"A Python shell. Use this to execute python commands. "
"Input should be a valid python command. "

@ -20,7 +20,7 @@ class SceneXplainInput(BaseModel):
class SceneXplainTool(BaseTool):
"""Tool that adds the capability to explain images."""
name = "Image Explainer"
name = "image_explainer"
description = (
"An Image Captioning Tool: Use this tool to generate a detailed caption "
"for an image. The input can be an image file of any format, and "

@ -14,7 +14,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
class SearxSearchRun(BaseTool):
"""Tool that adds the capability to query a Searx instance."""
name = "Searx Search"
name = "searx_search"
description = (
"A meta search engine."
"Useful for when you need to answer questions about current events."

@ -13,7 +13,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
class WolframAlphaQueryRun(BaseTool):
"""Tool that adds the capability to query using the Wolfram Alpha SDK."""
name = "Wolfram Alpha"
name = "wolfram_alpha"
description = (
"A wrapper around Wolfram Alpha. "
"Useful for when you need to answer questions about Math, "

@ -19,7 +19,7 @@ from langchain.tools import BaseTool
class YouTubeSearchTool(BaseTool):
name = "YouTubeSearch"
name = "youtube_search"
description = (
"search for youtube videos associated with a person. "
"the input to this tool should be a comma separated list, "

@ -164,7 +164,7 @@ class ZapierNLAListActions(BaseTool):
"""
name = "Zapier NLA: List Actions"
name = "ZapierNLA_list_actions"
description = BASE_ZAPIER_TOOL_PROMPT + (
"This tool returns a list of the user's exposed actions."
)

Loading…
Cancel
Save