From e82687ddf42e63f1e793c22e5f35123c83e5d52c Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 15 Jun 2023 08:18:50 -0700 Subject: [PATCH] Harrison/use functions agent (#6185) Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com> --- .../agents/toolkits/examples/csv.ipynb | 173 ++++++++++------ .../agents/toolkits/examples/pandas.ipynb | 59 ++++-- .../agents/toolkits/examples/python.ipynb | 189 +++++++++++------- .../toolkits/examples/sql_database.ipynb | 119 ++++++++--- .../agents/agent_toolkits/pandas/base.py | 166 +++++++++++++-- .../agents/agent_toolkits/pandas/prompt.py | 14 ++ .../agents/agent_toolkits/python/base.py | 37 +++- langchain/agents/agent_toolkits/sql/base.py | 71 +++++-- langchain/agents/agent_toolkits/sql/prompt.py | 2 + .../agents/openai_functions_agent/base.py | 53 +++-- langchain/tools/arxiv/tool.py | 2 +- .../form_recognizer.py | 2 +- .../image_analysis.py | 2 +- .../azure_cognitive_services/speech2text.py | 2 +- .../azure_cognitive_services/text2speech.py | 2 +- langchain/tools/bing_search/tool.py | 2 +- langchain/tools/brave_search/tool.py | 2 +- langchain/tools/ddg_search/tool.py | 2 +- langchain/tools/google_places/tool.py | 2 +- langchain/tools/google_search/tool.py | 2 +- langchain/tools/google_serper/tool.py | 2 +- langchain/tools/human/tool.py | 2 +- langchain/tools/metaphor_search/tool.py | 2 +- langchain/tools/python/tool.py | 2 +- langchain/tools/scenexplain/tool.py | 2 +- langchain/tools/searx_search/tool.py | 2 +- langchain/tools/wolfram_alpha/tool.py | 2 +- langchain/tools/youtube/search.py | 2 +- langchain/tools/zapier/tool.py | 2 +- 29 files changed, 674 insertions(+), 247 deletions(-) diff --git a/docs/modules/agents/toolkits/examples/csv.ipynb b/docs/modules/agents/toolkits/examples/csv.ipynb index 030f5e58..4d7c7641 100644 --- a/docs/modules/agents/toolkits/examples/csv.ipynb +++ b/docs/modules/agents/toolkits/examples/csv.ipynb @@ -30,38 +30,83 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.llms import OpenAI" + "from langchain.llms import OpenAI\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents.agent_types import AgentType" + ] + }, + { + "cell_type": "markdown", + "id": "bd806175", + "metadata": {}, + "source": [ + "## Using ZERO_SHOT_REACT_DESCRIPTION\n", + "\n", + "This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, + "id": "a1717204", + "metadata": {}, + "outputs": [], + "source": [ + "agent = create_csv_agent(\n", + " OpenAI(temperature=0), \n", + " 'titanic.csv', \n", + " verbose=True, \n", + " agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c31bb8a6", + "metadata": {}, + "source": [ + "## Using OpenAI Functions\n", + "\n", + "This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above." + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "16c4dc59", "metadata": {}, "outputs": [], "source": [ - "agent = create_csv_agent(OpenAI(temperature=0), 'titanic.csv', verbose=True)" + "agent = create_csv_agent(\n", + " ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n", + " 'titanic.csv', \n", + " verbose=True, \n", + " agent_type=AgentType.OPENAI_FUNCTIONS\n", + ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "46b9489d", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `python_repl_ast` with `df.shape[0]`\n", "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n", - "Action: python_repl_ast\n", - "Action Input: df.shape[0]\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: There are 891 rows.\u001b[0m\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -69,10 +114,10 @@ { "data": { "text/plain": [ - "'There are 891 rows.'" + "'There are 891 rows in the dataframe.'" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -83,23 +128,28 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "a96309be", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `python_repl_ast` with `df[df['SibSp'] > 3]['PassengerId'].count()`\n", "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mThought: I need to count the number of people with more than 3 siblings\n", - "Action: python_repl_ast\n", - "Action Input: df[df['SibSp'] > 3].shape[0]\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m30\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 30 people have more than 3 siblings.\u001b[0m\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m30\u001b[0m\u001b[32;1m\u001b[1;3mThere are 30 people in the dataframe who have more than 3 siblings.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -107,10 +157,10 @@ { "data": { "text/plain": [ - "'30 people have more than 3 siblings.'" + "'There are 30 people in the dataframe who have more than 3 siblings.'" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -121,35 +171,39 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "964a09f7", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `python_repl_ast` with `import pandas as pd\n", + "import math\n", + "\n", + "# Create a dataframe\n", + "data = {'Age': [22, 38, 26, 35, 35]}\n", + "df = pd.DataFrame(data)\n", + "\n", + "# Calculate the average age\n", + "average_age = df['Age'].mean()\n", + "\n", + "# Calculate the square root of the average age\n", + "square_root = math.sqrt(average_age)\n", + "\n", + "square_root`\n", "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mThought: I need to calculate the average age first\n", - "Action: python_repl_ast\n", - "Action Input: df['Age'].mean()\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m29.69911764705882\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n", - "Action: python_repl_ast\n", - "Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mNameError(\"name 'math' is not defined\")\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to import the math library\n", - "Action: python_repl_ast\n", - "Action Input: import math\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate the square root of the average age\n", - "Action: python_repl_ast\n", - "Action Input: math.sqrt(df['Age'].mean())\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m5.449689683556195\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 5.449689683556195\u001b[0m\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m5.585696017507576\u001b[0m\u001b[32;1m\u001b[1;3mThe square root of the average age is approximately 5.59.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -157,10 +211,10 @@ { "data": { "text/plain": [ - "'5.449689683556195'" + "'The square root of the average age is approximately 5.59.'" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -181,23 +235,26 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "id": "15f11fbd", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `python_repl_ast` with `df1['Age'].nunique() - df2['Age'].nunique()`\n", "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mThought: I need to compare the age columns in both dataframes\n", - "Action: python_repl_ast\n", - "Action Input: len(df1[df1['Age'] != df2['Age']])\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m177\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 177 rows in the age column are different.\u001b[0m\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m-1\u001b[0m\u001b[32;1m\u001b[1;3mThere is 1 row in the age column that is different between the two dataframes.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -205,17 +262,17 @@ { "data": { "text/plain": [ - "'177 rows in the age column are different.'" + "'There is 1 row in the age column that is different between the two dataframes.'" ] }, - "execution_count": 3, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "agent = create_csv_agent(OpenAI(temperature=0), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True)\n", - "agent.run(\"how many rows in the age column are different?\")" + "agent = create_csv_agent(ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), ['titanic.csv', 'titanic_age_fillna.csv'], verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS)\n", + "agent.run(\"how many rows in the age column are different between the two dfs?\")" ] }, { diff --git a/docs/modules/agents/toolkits/examples/pandas.ipynb b/docs/modules/agents/toolkits/examples/pandas.ipynb index f2acf6ff..c28aef2d 100644 --- a/docs/modules/agents/toolkits/examples/pandas.ipynb +++ b/docs/modules/agents/toolkits/examples/pandas.ipynb @@ -19,7 +19,9 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.agents import create_pandas_dataframe_agent" + "from langchain.agents import create_pandas_dataframe_agent\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents.agent_types import AgentType" ] }, { @@ -35,6 +37,16 @@ "df = pd.read_csv('titanic.csv')" ] }, + { + "cell_type": "markdown", + "id": "a62858e2", + "metadata": {}, + "source": [ + "## Using ZERO_SHOT_REACT_DESCRIPTION\n", + "\n", + "This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above." + ] + }, { "cell_type": "code", "execution_count": 3, @@ -45,9 +57,34 @@ "agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)" ] }, + { + "cell_type": "markdown", + "id": "7233ab56", + "metadata": {}, + "source": [ + "## Using OpenAI Functions\n", + "\n", + "This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above." + ] + }, { "cell_type": "code", "execution_count": 4, + "id": "a8ea710e", + "metadata": {}, + "outputs": [], + "source": [ + "agent = create_pandas_dataframe_agent(\n", + " ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"), \n", + " df, \n", + " verbose=True, \n", + " agent_type=AgentType.OPENAI_FUNCTIONS\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "id": "a9207a2e", "metadata": {}, "outputs": [ @@ -57,13 +94,12 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mThought: I need to count the number of rows\n", - "Action: python_repl_ast\n", - "Action Input: df.shape[0]\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: There are 891 rows.\u001b[0m\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `python_repl_ast` with `df.shape[0]`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m891\u001b[0m\u001b[32;1m\u001b[1;3mThere are 891 rows in the dataframe.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -71,10 +107,10 @@ { "data": { "text/plain": [ - "'There are 891 rows.'" + "'There are 891 rows in the dataframe.'" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -172,7 +208,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "c4bc0584", "metadata": {}, @@ -257,7 +292,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/agents/toolkits/examples/python.ipynb b/docs/modules/agents/toolkits/examples/python.ipynb index 1c05a1f9..06870d75 100644 --- a/docs/modules/agents/toolkits/examples/python.ipynb +++ b/docs/modules/agents/toolkits/examples/python.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "f98e9c90-5c37-4fb9-af3e-d09693af8543", "metadata": { "tags": [] @@ -22,12 +22,24 @@ "from langchain.agents.agent_toolkits import create_python_agent\n", "from langchain.tools.python.tool import PythonREPLTool\n", "from langchain.python import PythonREPL\n", - "from langchain.llms.openai import OpenAI" + "from langchain.llms.openai import OpenAI\n", + "from langchain.agents.agent_types import AgentType\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "markdown", + "id": "ca30d64c", + "metadata": {}, + "source": [ + "## Using ZERO_SHOT_REACT_DESCRIPTION\n", + "\n", + "This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "cc422f53-c51c-4694-a834-72ecd1e68363", "metadata": { "tags": [] @@ -37,7 +49,34 @@ "agent_executor = create_python_agent(\n", " llm=OpenAI(temperature=0, max_tokens=1000),\n", " tool=PythonREPLTool(),\n", - " verbose=True\n", + " verbose=True,\n", + " agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bb487e8e", + "metadata": {}, + "source": [ + "## Using OpenAI Functions\n", + "\n", + "This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6e651822", + "metadata": {}, + "outputs": [], + "source": [ + "agent_executor = create_python_agent(\n", + " llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n", + " tool=PythonREPLTool(),\n", + " verbose=True,\n", + " agent_type=AgentType.OPENAI_FUNCTIONS,\n", + " agent_executor_kwargs={\"handle_parsing_errors\": True},\n", ")" ] }, @@ -52,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "25cd4f92-ea9b-4fe6-9838-a4f85f81eebe", "metadata": { "tags": [] @@ -64,23 +103,20 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to calculate the 10th fibonacci number\n", - "Action: Python REPL\n", - "Action Input: def fibonacci(n):\n", - " if n == 0:\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Python_REPL` with `def fibonacci(n):\n", + " if n <= 0:\n", " return 0\n", " elif n == 1:\n", " return 1\n", " else:\n", - " return fibonacci(n-1) + fibonacci(n-2)\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to call the function with 10 as the argument\n", - "Action: Python REPL\n", - "Action Input: fibonacci(10)\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 55\u001b[0m\n", + " return fibonacci(n-1) + fibonacci(n-2)\n", + "\n", + "fibonacci(10)`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe 10th Fibonacci number is 55.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -88,10 +124,10 @@ { "data": { "text/plain": [ - "'55'" + "'The 10th Fibonacci number is 55.'" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -111,9 +147,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "4b9f60e7-eb6a-4f14-8604-498d863d4482", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stdout", @@ -121,59 +159,70 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to write a neural network in PyTorch and train it on the given data.\n", - "Action: Python REPL\n", - "Action Input: \n", - "import torch\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mCould not parse tool input: {'name': 'python', 'arguments': 'import torch\\nimport torch.nn as nn\\nimport torch.optim as optim\\n\\n# Define the neural network\\nclass SingleNeuron(nn.Module):\\n def __init__(self):\\n super(SingleNeuron, self).__init__()\\n self.linear = nn.Linear(1, 1)\\n \\n def forward(self, x):\\n return self.linear(x)\\n\\n# Create the synthetic data\\nx_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\\ny_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\\n\\n# Create the neural network\\nmodel = SingleNeuron()\\n\\n# Define the loss function and optimizer\\ncriterion = nn.MSELoss()\\noptimizer = optim.SGD(model.parameters(), lr=0.01)\\n\\n# Train the neural network\\nfor epoch in range(1, 1001):\\n # Forward pass\\n y_pred = model(x_train)\\n \\n # Compute loss\\n loss = criterion(y_pred, y_train)\\n \\n # Backward pass and optimization\\n optimizer.zero_grad()\\n loss.backward()\\n optimizer.step()\\n \\n # Print the loss every 100 epochs\\n if epoch % 100 == 0:\\n print(f\"Epoch {epoch}: Loss = {loss.item()}\")\\n\\n# Make a prediction for x = 5\\nx_test = torch.tensor([[5.0]], dtype=torch.float32)\\ny_pred = model(x_test)\\ny_pred.item()'} because the `arguments` is not valid JSON.\u001b[0mInvalid or incomplete response\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Python_REPL` with `import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", "\n", - "# Define the model\n", - "model = torch.nn.Sequential(\n", - " torch.nn.Linear(1, 1)\n", - ")\n", + "# Define the neural network\n", + "class SingleNeuron(nn.Module):\n", + " def __init__(self):\n", + " super(SingleNeuron, self).__init__()\n", + " self.linear = nn.Linear(1, 1)\n", + " \n", + " def forward(self, x):\n", + " return self.linear(x)\n", "\n", - "# Define the loss\n", - "loss_fn = torch.nn.MSELoss()\n", + "# Create the synthetic data\n", + "x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)\n", + "y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)\n", "\n", - "# Define the optimizer\n", - "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", + "# Create the neural network\n", + "model = SingleNeuron()\n", "\n", - "# Define the data\n", - "x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n", - "y_data = torch.tensor([[2.0], [4.0], [6.0], [8.0]])\n", + "# Define the loss function and optimizer\n", + "criterion = nn.MSELoss()\n", + "optimizer = optim.SGD(model.parameters(), lr=0.01)\n", "\n", - "# Train the model\n", - "for epoch in range(1000):\n", + "# Train the neural network\n", + "for epoch in range(1, 1001):\n", " # Forward pass\n", - " y_pred = model(x_data)\n", + " y_pred = model(x_train)\n", + " \n", + " # Compute loss\n", + " loss = criterion(y_pred, y_train)\n", + " \n", + " # Backward pass and optimization\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " # Print the loss every 100 epochs\n", + " if epoch % 100 == 0:\n", + " print(f\"Epoch {epoch}: Loss = {loss.item()}\")\n", "\n", - " # Compute and print loss\n", - " loss = loss_fn(y_pred, y_data)\n", - " if (epoch+1) % 100 == 0:\n", - " print(f'Epoch {epoch+1}: loss = {loss.item():.4f}')\n", + "# Make a prediction for x = 5\n", + "x_test = torch.tensor([[5.0]], dtype=torch.float32)\n", + "y_pred = model(x_test)\n", + "y_pred.item()`\n", "\n", - " # Zero the gradients\n", - " optimizer.zero_grad()\n", "\n", - " # Backward pass\n", - " loss.backward()\n", + "\u001b[0m\u001b[36;1m\u001b[1;3mEpoch 100: Loss = 0.03825576975941658\n", + "Epoch 200: Loss = 0.02100197970867157\n", + "Epoch 300: Loss = 0.01152981910854578\n", + "Epoch 400: Loss = 0.006329738534986973\n", + "Epoch 500: Loss = 0.0034749575424939394\n", + "Epoch 600: Loss = 0.0019077073084190488\n", + "Epoch 700: Loss = 0.001047312980517745\n", + "Epoch 800: Loss = 0.0005749554838985205\n", + "Epoch 900: Loss = 0.0003156439634039998\n", + "Epoch 1000: Loss = 0.00017328384274151176\n", + "\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Python_REPL` with `x_test.item()`\n", "\n", - " # Update the weights\n", - " optimizer.step()\n", - "\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mEpoch 100: loss = 0.0013\n", - "Epoch 200: loss = 0.0007\n", - "Epoch 300: loss = 0.0004\n", - "Epoch 400: loss = 0.0002\n", - "Epoch 500: loss = 0.0001\n", - "Epoch 600: loss = 0.0001\n", - "Epoch 700: loss = 0.0000\n", - "Epoch 800: loss = 0.0000\n", - "Epoch 900: loss = 0.0000\n", - "Epoch 1000: loss = 0.0000\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: The prediction for x = 5 is 10.0.\u001b[0m\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3mThe prediction for x = 5 is 10.000173568725586.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -181,10 +230,10 @@ { "data": { "text/plain": [ - "'The prediction for x = 5 is 10.0.'" + "'The prediction for x = 5 is 10.000173568725586.'" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -206,9 +255,9 @@ ], "metadata": { "kernelspec": { - "display_name": "LangChain", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "langchain" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -220,7 +269,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/agents/toolkits/examples/sql_database.ipynb b/docs/modules/agents/toolkits/examples/sql_database.ipynb index f9cb0895..0360bc73 100644 --- a/docs/modules/agents/toolkits/examples/sql_database.ipynb +++ b/docs/modules/agents/toolkits/examples/sql_database.ipynb @@ -37,7 +37,30 @@ "from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n", "from langchain.sql_database import SQLDatabase\n", "from langchain.llms.openai import OpenAI\n", - "from langchain.agents import AgentExecutor" + "from langchain.agents import AgentExecutor\n", + "from langchain.agents.agent_types import AgentType\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "65ec5bb3", + "metadata": {}, + "outputs": [], + "source": [ + "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", + "toolkit = SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0))\n" + ] + }, + { + "cell_type": "markdown", + "id": "f74d1792", + "metadata": {}, + "source": [ + "## Using ZERO_SHOT_REACT_DESCRIPTION\n", + "\n", + "This shows how to initialize the agent using the ZERO_SHOT_REACT_DESCRIPTION agent type. Note that this is an alternative to the above." ] }, { @@ -49,16 +72,39 @@ }, "outputs": [], "source": [ - "db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n", - "toolkit = SQLDatabaseToolkit(db=db)\n", - "\n", "agent_executor = create_sql_agent(\n", " llm=OpenAI(temperature=0),\n", " toolkit=toolkit,\n", - " verbose=True\n", + " verbose=True,\n", + " agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n", ")" ] }, + { + "cell_type": "markdown", + "id": "971cc455", + "metadata": {}, + "source": [ + "## Using OpenAI Functions\n", + "\n", + "This shows how to initialize the agent using the OPENAI_FUNCTIONS agent type. Note that this is an alternative to the above." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6426a27d", + "metadata": {}, + "outputs": [], + "source": [ + "# agent_executor = create_sql_agent(\n", + "# llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n", + "# toolkit=toolkit,\n", + "# verbose=True,\n", + "# agent_type=AgentType.OPENAI_FUNCTIONS\n", + "# )" + ] + }, { "cell_type": "markdown", "id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b", @@ -69,7 +115,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7", "metadata": { "tags": [] @@ -81,14 +127,16 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n", - "Action Input: \"\"\u001b[0m\n", - "Observation: \u001b[38;5;200m\u001b[1;3mArtist, Invoice, Playlist, Genre, Album, PlaylistTrack, Track, InvoiceLine, MediaType, Employee, Customer\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the playlisttrack table\n", - "Action: schema_sql_db\n", - "Action Input: \"PlaylistTrack\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `list_tables_sql_db` with `{}`\n", + "\n", + "\n", + "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Track, PlaylistTrack, InvoiceLine, sales_table, Playlist, Genre, Employee, Customer, Invoice, MediaType\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `schema_sql_db` with `PlaylistTrack`\n", + "\n", + "\n", + "\u001b[0m\u001b[33;1m\u001b[1;3m\n", "CREATE TABLE \"PlaylistTrack\" (\n", "\t\"PlaylistId\" INTEGER NOT NULL, \n", "\t\"TrackId\" INTEGER NOT NULL, \n", @@ -97,13 +145,36 @@ "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", ")\n", "\n", - "SELECT * FROM 'PlaylistTrack' LIMIT 3;\n", - "PlaylistId TrackId\n", - "1 3402\n", - "1 3389\n", - "1 3390\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.\u001b[0m\n", + "/*\n", + "3 rows from PlaylistTrack table:\n", + "PlaylistId\tTrackId\n", + "1\t3402\n", + "1\t3389\n", + "1\t3390\n", + "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \n", + "\n", + "Here is the schema of the `PlaylistTrack` table:\n", + "\n", + "```\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "```\n", + "\n", + "Here are three sample rows from the `PlaylistTrack` table:\n", + "\n", + "```\n", + "PlaylistId TrackId\n", + "1 3402\n", + "1 3389\n", + "1 3390\n", + "```\n", + "\n", + "Please let me know if there is anything else I can help you with.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -111,10 +182,10 @@ { "data": { "text/plain": [ - "'The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.'" + "'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help you with.'" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -519,7 +590,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/agents/agent_toolkits/pandas/base.py b/langchain/agents/agent_toolkits/pandas/base.py index d8599a90..03eddbc8 100644 --- a/langchain/agents/agent_toolkits/pandas/base.py +++ b/langchain/agents/agent_toolkits/pandas/base.py @@ -1,19 +1,26 @@ """Agent for working with pandas objects.""" from typing import Any, Dict, List, Optional, Tuple -from langchain.agents.agent import AgentExecutor +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent from langchain.agents.agent_toolkits.pandas.prompt import ( + FUNCTIONS_WITH_DF, + FUNCTIONS_WITH_MULTI_DF, MULTI_DF_PREFIX, + MULTI_DF_PREFIX_FUNCTIONS, PREFIX, + PREFIX_FUNCTIONS, SUFFIX_NO_DF, SUFFIX_WITH_DF, SUFFIX_WITH_MULTI_DF, ) from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.agents.types import AgentType from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate +from langchain.schema import SystemMessage from langchain.tools.python.tool import PythonAstREPLTool @@ -137,9 +144,108 @@ def _get_prompt_and_tools( ) +def _get_functions_single_prompt( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + include_df_in_prompt: Optional[bool] = True, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + if suffix is not None: + suffix_to_use = suffix + if include_df_in_prompt: + suffix_to_use = suffix_to_use.format(df_head=str(df.head().to_markdown())) + elif include_df_in_prompt: + suffix_to_use = FUNCTIONS_WITH_DF.format(df_head=str(df.head().to_markdown())) + else: + suffix_to_use = "" + + if prefix is None: + prefix = PREFIX_FUNCTIONS + + tools = [PythonAstREPLTool(locals={"df": df})] + system_message = SystemMessage(content=prefix + suffix_to_use) + prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + return prompt, tools + + +def _get_functions_multi_prompt( + dfs: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + include_df_in_prompt: Optional[bool] = True, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + if suffix is not None: + suffix_to_use = suffix + if include_df_in_prompt: + dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs]) + suffix_to_use = suffix_to_use.format( + dfs_head=dfs_head, + ) + elif include_df_in_prompt: + dfs_head = "\n\n".join([d.head().to_markdown() for d in dfs]) + suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format( + dfs_head=dfs_head, + ) + else: + suffix_to_use = "" + + if prefix is None: + prefix = MULTI_DF_PREFIX_FUNCTIONS + prefix = prefix.format(num_dfs=str(len(dfs))) + + df_locals = {} + for i, dataframe in enumerate(dfs): + df_locals[f"df{i + 1}"] = dataframe + tools = [PythonAstREPLTool(locals=df_locals)] + system_message = SystemMessage(content=prefix + suffix_to_use) + prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + return prompt, tools + + +def _get_functions_prompt_and_tools( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + include_df_in_prompt: Optional[bool] = True, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + try: + import pandas as pd + except ImportError: + raise ValueError( + "pandas package not found, please install with `pip install pandas`" + ) + if input_variables is not None: + raise ValueError("`input_variables` is not supported at the moment.") + + if include_df_in_prompt is not None and suffix is not None: + raise ValueError("If suffix is specified, include_df_in_prompt should not be.") + + if isinstance(df, list): + for item in df: + if not isinstance(item, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_functions_multi_prompt( + df, + prefix=prefix, + suffix=suffix, + include_df_in_prompt=include_df_in_prompt, + ) + else: + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_functions_single_prompt( + df, + prefix=prefix, + suffix=suffix, + include_df_in_prompt=include_df_in_prompt, + ) + + def create_pandas_dataframe_agent( llm: BaseLanguageModel, df: Any, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, callback_manager: Optional[BaseCallbackManager] = None, prefix: Optional[str] = None, suffix: Optional[str] = None, @@ -154,26 +260,44 @@ def create_pandas_dataframe_agent( **kwargs: Dict[str, Any], ) -> AgentExecutor: """Construct a pandas agent from an LLM and dataframe.""" - - prompt, tools = _get_prompt_and_tools( - df, - prefix=prefix, - suffix=suffix, - input_variables=input_variables, - include_df_in_prompt=include_df_in_prompt, - ) - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - callback_manager=callback_manager, - ) - tool_names = [tool.name for tool in tools] - agent = ZeroShotAgent( - llm_chain=llm_chain, - allowed_tools=tool_names, - callback_manager=callback_manager, - **kwargs, - ) + agent: BaseSingleActionAgent + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt, tools = _get_prompt_and_tools( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent( + llm_chain=llm_chain, + allowed_tools=tool_names, + callback_manager=callback_manager, + **kwargs, + ) + elif agent_type == AgentType.OPENAI_FUNCTIONS: + _prompt, tools = _get_functions_prompt_and_tools( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + ) + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, diff --git a/langchain/agents/agent_toolkits/pandas/prompt.py b/langchain/agents/agent_toolkits/pandas/prompt.py index af9f38c3..72b2bc8b 100644 --- a/langchain/agents/agent_toolkits/pandas/prompt.py +++ b/langchain/agents/agent_toolkits/pandas/prompt.py @@ -28,3 +28,17 @@ This is the result of `print(df.head())` for each dataframe: Begin! Question: {input} {agent_scratchpad}""" + +PREFIX_FUNCTIONS = """ +You are working with a pandas dataframe in Python. The name of the dataframe is `df`.""" + +MULTI_DF_PREFIX_FUNCTIONS = """ +You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc.""" + +FUNCTIONS_WITH_DF = """ +This is the result of `print(df.head())`: +{df_head}""" + +FUNCTIONS_WITH_MULTI_DF = """ +This is the result of `print(df.head())` for each dataframe: +{dfs_head}""" diff --git a/langchain/agents/agent_toolkits/python/base.py b/langchain/agents/agent_toolkits/python/base.py index 2db17642..e86c6ca6 100644 --- a/langchain/agents/agent_toolkits/python/base.py +++ b/langchain/agents/agent_toolkits/python/base.py @@ -2,18 +2,22 @@ from typing import Any, Dict, Optional -from langchain.agents.agent import AgentExecutor +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent from langchain.agents.agent_toolkits.python.prompt import PREFIX from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.agents.types import AgentType from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain +from langchain.schema import SystemMessage from langchain.tools.python.tool import PythonREPLTool def create_python_agent( llm: BaseLanguageModel, tool: PythonREPLTool, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, callback_manager: Optional[BaseCallbackManager] = None, verbose: bool = False, prefix: str = PREFIX, @@ -22,14 +26,29 @@ def create_python_agent( ) -> AgentExecutor: """Construct a python agent from an LLM and tool.""" tools = [tool] - prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - callback_manager=callback_manager, - ) - tool_names = [tool.name for tool in tools] - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + elif agent_type == AgentType.OPENAI_FUNCTIONS: + system_message = SystemMessage(content=prefix) + _prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, diff --git a/langchain/agents/agent_toolkits/sql/base.py b/langchain/agents/agent_toolkits/sql/base.py index 784d3155..82fe07cc 100644 --- a/langchain/agents/agent_toolkits/sql/base.py +++ b/langchain/agents/agent_toolkits/sql/base.py @@ -1,22 +1,35 @@ """SQL agent.""" from typing import Any, Dict, List, Optional -from langchain.agents.agent import AgentExecutor -from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent +from langchain.agents.agent_toolkits.sql.prompt import ( + SQL_FUNCTIONS_SUFFIX, + SQL_PREFIX, + SQL_SUFFIX, +) from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain.agents.agent_types import AgentType from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain.schema import AIMessage, SystemMessage def create_sql_agent( llm: BaseLanguageModel, toolkit: SQLDatabaseToolkit, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = SQL_PREFIX, - suffix: str = SQL_SUFFIX, + suffix: Optional[str] = None, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: Optional[List[str]] = None, top_k: int = 10, @@ -30,20 +43,44 @@ def create_sql_agent( """Construct a sql agent from an LLM and tools.""" tools = toolkit.get_tools() prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) - prompt = ZeroShotAgent.create_prompt( - tools, - prefix=prefix, - suffix=suffix, - format_instructions=format_instructions, - input_variables=input_variables, - ) - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - callback_manager=callback_manager, - ) - tool_names = [tool.name for tool in tools] - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix or SQL_SUFFIX, + format_instructions=format_instructions, + input_variables=input_variables, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + + elif agent_type == AgentType.OPENAI_FUNCTIONS: + messages = [ + SystemMessage(content=prefix), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + input_variables = ["input", "agent_scratchpad"] + _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages) + + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, diff --git a/langchain/agents/agent_toolkits/sql/prompt.py b/langchain/agents/agent_toolkits/sql/prompt.py index 2a8ced66..92464da4 100644 --- a/langchain/agents/agent_toolkits/sql/prompt.py +++ b/langchain/agents/agent_toolkits/sql/prompt.py @@ -19,3 +19,5 @@ SQL_SUFFIX = """Begin! Question: {input} Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. {agent_scratchpad}""" + +SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.""" diff --git a/langchain/agents/openai_functions_agent/base.py b/langchain/agents/openai_functions_agent/base.py index 2276f1a7..4b91bbb2 100644 --- a/langchain/agents/openai_functions_agent/base.py +++ b/langchain/agents/openai_functions_agent/base.py @@ -7,12 +7,11 @@ from typing import Any, List, Optional, Sequence, Tuple, Union from langchain.agents import BaseSingleActionAgent from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import ( - Callbacks, -) +from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import ChatOpenAI from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( + BaseMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, @@ -23,6 +22,7 @@ from langchain.schema import ( AIMessage, BaseMessage, FunctionMessage, + OutputParserException, SystemMessage, ) from langchain.tools import BaseTool @@ -34,7 +34,9 @@ class _FunctionsAgentAction(AgentAction): message_log: List[BaseMessage] -def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]: +def _convert_agent_action_to_messages( + agent_action: AgentAction, observation: str +) -> List[BaseMessage]: """Convert an agent action to a message. This code is used to reconstruct the original AI message from the agent action. @@ -45,9 +47,12 @@ def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMes Returns: AIMessage that corresponds to the original tool invocation. """ - if not isinstance(agent_action, _FunctionsAgentAction): - raise ValueError("This agent type only works with _FunctionsAgentAction") - return agent_action.message_log + if isinstance(agent_action, _FunctionsAgentAction): + return agent_action.message_log + [ + _create_function_message(agent_action, observation) + ] + else: + return [AIMessage(content=agent_action.log)] def _create_function_message( @@ -61,7 +66,10 @@ def _create_function_message( FunctionMessage that corresponds to the original tool invocation """ if not isinstance(observation, str): - content = json.dumps(observation) + try: + content = json.dumps(observation) + except Exception: + content = str(observation) else: content = observation return FunctionMessage( @@ -83,8 +91,7 @@ def _format_intermediate_steps( for intermediate_step in intermediate_steps: agent_action, observation = intermediate_step - messages.extend(_convert_agent_action_to_messages(agent_action)) - messages.append(_create_function_message(agent_action, observation)) + messages.extend(_convert_agent_action_to_messages(agent_action, observation)) return messages @@ -102,7 +109,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: try: _tool_input = json.loads(function_call["arguments"]) except JSONDecodeError: - raise ValueError( + raise OutputParserException( f"Could not parse tool input: {function_call} because " f"the `arguments` is not valid JSON." ) @@ -202,12 +209,24 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): return agent_decision @classmethod - def create_prompt(cls) -> BasePromptTemplate: - messages = [ - SystemMessage(content="You are a helpful AI assistant."), - HumanMessagePromptTemplate.from_template("{input}"), - MessagesPlaceholder(variable_name="agent_scratchpad"), - ] + def create_prompt( + cls, + system_message: Optional[SystemMessage] = SystemMessage( + content="You are a helpful AI assistant." + ), + ) -> BasePromptTemplate: + messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] + if system_message: + messages = [system_message] + else: + messages = [] + + messages.extend( + [ + HumanMessagePromptTemplate.from_template("{input}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) input_variables = ["input", "agent_scratchpad"] return ChatPromptTemplate(input_variables=input_variables, messages=messages) diff --git a/langchain/tools/arxiv/tool.py b/langchain/tools/arxiv/tool.py index a9144dd8..f032cb0b 100644 --- a/langchain/tools/arxiv/tool.py +++ b/langchain/tools/arxiv/tool.py @@ -15,7 +15,7 @@ from langchain.utilities.arxiv import ArxivAPIWrapper class ArxivQueryRun(BaseTool): """Tool that adds the capability to search using the Arxiv API.""" - name = "Arxiv" + name = "arxiv" description = ( "A wrapper around Arxiv.org " "Useful for when you need to answer questions about Physics, Mathematics, " diff --git a/langchain/tools/azure_cognitive_services/form_recognizer.py b/langchain/tools/azure_cognitive_services/form_recognizer.py index 2b7bdd98..d28d69d5 100644 --- a/langchain/tools/azure_cognitive_services/form_recognizer.py +++ b/langchain/tools/azure_cognitive_services/form_recognizer.py @@ -27,7 +27,7 @@ class AzureCogsFormRecognizerTool(BaseTool): azure_cogs_endpoint: str = "" #: :meta private: doc_analysis_client: Any #: :meta private: - name = "Azure Cognitive Services Form Recognizer" + name = "azure_cognitive_services_form_recognizer" description = ( "A wrapper around Azure Cognitive Services Form Recognizer. " "Useful for when you need to " diff --git a/langchain/tools/azure_cognitive_services/image_analysis.py b/langchain/tools/azure_cognitive_services/image_analysis.py index c327bf58..4a8aaa94 100644 --- a/langchain/tools/azure_cognitive_services/image_analysis.py +++ b/langchain/tools/azure_cognitive_services/image_analysis.py @@ -28,7 +28,7 @@ class AzureCogsImageAnalysisTool(BaseTool): vision_service: Any #: :meta private: analysis_options: Any #: :meta private: - name = "Azure Cognitive Services Image Analysis" + name = "azure_cognitive_services_image_analysis" description = ( "A wrapper around Azure Cognitive Services Image Analysis. " "Useful for when you need to analyze images. " diff --git a/langchain/tools/azure_cognitive_services/speech2text.py b/langchain/tools/azure_cognitive_services/speech2text.py index fd9df593..dd358d5c 100644 --- a/langchain/tools/azure_cognitive_services/speech2text.py +++ b/langchain/tools/azure_cognitive_services/speech2text.py @@ -32,7 +32,7 @@ class AzureCogsSpeech2TextTool(BaseTool): speech_language: str = "en-US" #: :meta private: speech_config: Any #: :meta private: - name = "Azure Cognitive Services Speech2Text" + name = "azure_cognitive_services_speech2text" description = ( "A wrapper around Azure Cognitive Services Speech2Text. " "Useful for when you need to transcribe audio to text. " diff --git a/langchain/tools/azure_cognitive_services/text2speech.py b/langchain/tools/azure_cognitive_services/text2speech.py index c6d5a026..2725b516 100644 --- a/langchain/tools/azure_cognitive_services/text2speech.py +++ b/langchain/tools/azure_cognitive_services/text2speech.py @@ -28,7 +28,7 @@ class AzureCogsText2SpeechTool(BaseTool): speech_language: str = "en-US" #: :meta private: speech_config: Any #: :meta private: - name = "Azure Cognitive Services Text2Speech" + name = "azure_cognitive_services_text2speech" description = ( "A wrapper around Azure Cognitive Services Text2Speech. " "Useful for when you need to convert text to speech. " diff --git a/langchain/tools/bing_search/tool.py b/langchain/tools/bing_search/tool.py index 3340a55a..5e180ea2 100644 --- a/langchain/tools/bing_search/tool.py +++ b/langchain/tools/bing_search/tool.py @@ -13,7 +13,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper class BingSearchRun(BaseTool): """Tool that adds the capability to query the Bing search API.""" - name = "Bing Search" + name = "bing_search" description = ( "A wrapper around Bing Search. " "Useful for when you need to answer questions about current events. " diff --git a/langchain/tools/brave_search/tool.py b/langchain/tools/brave_search/tool.py index b8d41061..19a96a48 100644 --- a/langchain/tools/brave_search/tool.py +++ b/langchain/tools/brave_search/tool.py @@ -11,7 +11,7 @@ from langchain.utilities.brave_search import BraveSearchWrapper class BraveSearch(BaseTool): - name = "brave-search" + name = "brave_search" description = ( "a search engine. " "useful for when you need to answer questions about current events." diff --git a/langchain/tools/ddg_search/tool.py b/langchain/tools/ddg_search/tool.py index 109431e2..c0f9b450 100644 --- a/langchain/tools/ddg_search/tool.py +++ b/langchain/tools/ddg_search/tool.py @@ -16,7 +16,7 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper class DuckDuckGoSearchRun(BaseTool): """Tool that adds the capability to query the DuckDuckGo search API.""" - name = "DuckDuckGo Search" + name = "duckduckgo_search" description = ( "A wrapper around DuckDuckGo Search. " "Useful for when you need to answer questions about current events. " diff --git a/langchain/tools/google_places/tool.py b/langchain/tools/google_places/tool.py index 3f1f368a..f453dcd6 100644 --- a/langchain/tools/google_places/tool.py +++ b/langchain/tools/google_places/tool.py @@ -19,7 +19,7 @@ class GooglePlacesSchema(BaseModel): class GooglePlacesTool(BaseTool): """Tool that adds the capability to query the Google places API.""" - name = "Google Places" + name = "google_places" description = ( "A wrapper around Google Places. " "Useful for when you need to validate or " diff --git a/langchain/tools/google_search/tool.py b/langchain/tools/google_search/tool.py index 71288e19..11ea5cab 100644 --- a/langchain/tools/google_search/tool.py +++ b/langchain/tools/google_search/tool.py @@ -13,7 +13,7 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper class GoogleSearchRun(BaseTool): """Tool that adds the capability to query the Google search API.""" - name = "Google Search" + name = "google_search" description = ( "A wrapper around Google Search. " "Useful for when you need to answer questions about current events. " diff --git a/langchain/tools/google_serper/tool.py b/langchain/tools/google_serper/tool.py index 27d35f32..18d58408 100644 --- a/langchain/tools/google_serper/tool.py +++ b/langchain/tools/google_serper/tool.py @@ -15,7 +15,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper class GoogleSerperRun(BaseTool): """Tool that adds the capability to query the Serper.dev Google search API.""" - name = "Google Serper" + name = "google_serper" description = ( "A low-cost Google Search API." "Useful for when you need to answer questions about current events." diff --git a/langchain/tools/human/tool.py b/langchain/tools/human/tool.py index a207c6b1..4f8b2b3d 100644 --- a/langchain/tools/human/tool.py +++ b/langchain/tools/human/tool.py @@ -19,7 +19,7 @@ def _print_func(text: str) -> None: class HumanInputRun(BaseTool): """Tool that adds the capability to ask user for input.""" - name = "Human" + name = "human" description = ( "You can ask a human for guidance when you think you " "got stuck or you are not sure what to do next. " diff --git a/langchain/tools/metaphor_search/tool.py b/langchain/tools/metaphor_search/tool.py index 2e690111..cf4ea427 100644 --- a/langchain/tools/metaphor_search/tool.py +++ b/langchain/tools/metaphor_search/tool.py @@ -13,7 +13,7 @@ from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper class MetaphorSearchResults(BaseTool): """Tool that has capability to query the Metaphor Search API and get back json.""" - name = "Metaphor Search Results JSON" + name = "metaphor_search_results_json" description = ( "A wrapper around Metaphor Search. " "Input should be a Metaphor-optimized query. " diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index c53904e5..ddc98681 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -34,7 +34,7 @@ def sanitize_input(query: str) -> str: class PythonREPLTool(BaseTool): """A tool for running python code in a REPL.""" - name = "Python REPL" + name = "Python_REPL" description = ( "A Python shell. Use this to execute python commands. " "Input should be a valid python command. " diff --git a/langchain/tools/scenexplain/tool.py b/langchain/tools/scenexplain/tool.py index 7ac3a72e..0be2b12c 100644 --- a/langchain/tools/scenexplain/tool.py +++ b/langchain/tools/scenexplain/tool.py @@ -20,7 +20,7 @@ class SceneXplainInput(BaseModel): class SceneXplainTool(BaseTool): """Tool that adds the capability to explain images.""" - name = "Image Explainer" + name = "image_explainer" description = ( "An Image Captioning Tool: Use this tool to generate a detailed caption " "for an image. The input can be an image file of any format, and " diff --git a/langchain/tools/searx_search/tool.py b/langchain/tools/searx_search/tool.py index e3ea04b5..9a584843 100644 --- a/langchain/tools/searx_search/tool.py +++ b/langchain/tools/searx_search/tool.py @@ -14,7 +14,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper class SearxSearchRun(BaseTool): """Tool that adds the capability to query a Searx instance.""" - name = "Searx Search" + name = "searx_search" description = ( "A meta search engine." "Useful for when you need to answer questions about current events." diff --git a/langchain/tools/wolfram_alpha/tool.py b/langchain/tools/wolfram_alpha/tool.py index a243d22f..30ba2570 100644 --- a/langchain/tools/wolfram_alpha/tool.py +++ b/langchain/tools/wolfram_alpha/tool.py @@ -13,7 +13,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper class WolframAlphaQueryRun(BaseTool): """Tool that adds the capability to query using the Wolfram Alpha SDK.""" - name = "Wolfram Alpha" + name = "wolfram_alpha" description = ( "A wrapper around Wolfram Alpha. " "Useful for when you need to answer questions about Math, " diff --git a/langchain/tools/youtube/search.py b/langchain/tools/youtube/search.py index 0255434a..5ce5fd58 100644 --- a/langchain/tools/youtube/search.py +++ b/langchain/tools/youtube/search.py @@ -19,7 +19,7 @@ from langchain.tools import BaseTool class YouTubeSearchTool(BaseTool): - name = "YouTubeSearch" + name = "youtube_search" description = ( "search for youtube videos associated with a person. " "the input to this tool should be a comma separated list, " diff --git a/langchain/tools/zapier/tool.py b/langchain/tools/zapier/tool.py index 9709f5bb..96d3b762 100644 --- a/langchain/tools/zapier/tool.py +++ b/langchain/tools/zapier/tool.py @@ -164,7 +164,7 @@ class ZapierNLAListActions(BaseTool): """ - name = "Zapier NLA: List Actions" + name = "ZapierNLA_list_actions" description = BASE_ZAPIER_TOOL_PROMPT + ( "This tool returns a list of the user's exposed actions." )