diff --git a/docs/examples/agents/mrkl.ipynb b/docs/examples/agents/mrkl.ipynb index c0dcb817..3a069870 100644 --- a/docs/examples/agents/mrkl.ipynb +++ b/docs/examples/agents/mrkl.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "id": "ac561cc4", "metadata": {}, "outputs": [], @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "id": "07e96d99", "metadata": {}, "outputs": [], @@ -43,16 +43,16 @@ "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n", "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", "tools = [\n", - " Tool(\n", - " name = \"Search\",\n", - " func=search.run,\n", - " description=\"useful for when you need to answer questions about current events\"\n", - " ),\n", - " Tool(\n", - " name=\"Calculator\",\n", - " func=llm_math_chain.run,\n", - " description=\"useful for when you need to answer questions about math\"\n", - " ),\n", + "# Tool(\n", + "# name = \"Search\",\n", + "# func=search.run,\n", + "# description=\"useful for when you need to answer questions about current events\"\n", + "# ),\n", + "# Tool(\n", + "# name=\"Calculator\",\n", + "# func=llm_math_chain.run,\n", + "# description=\"useful for when you need to answer questions about math\"\n", + "# ),\n", " Tool(\n", " name=\"FooBar DB\",\n", " func=db_chain.run,\n", @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "id": "a069c4b6", "metadata": {}, "outputs": [], @@ -71,6 +71,153 @@ "mrkl = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)" ] }, + { + "cell_type": "code", + "execution_count": 11, + "id": "356a1396", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What was our total number of sales and total revenue in 2013?\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n", + "Action: FooBar DB\n", + "Action Input: sales 2013\u001b[0m\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "sales 2013\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE InvoiceDate BETWEEN '2013-01-01' AND '2013-12-31'\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n", + "Action: FooBar DB\n", + "Action Input: revenue 2013\u001b[0m\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "revenue 2013\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE strftime('%Y', InvoiceDate) = '2013'\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: We had 450.58 sales and 450.58 revenue in 2013\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'We had 450.58 sales and 450.58 revenue in 2013'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrkl.run(\"What was our total number of sales and total revenue in 2013?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4e9c9b23", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import ZeroShotAgent\n", + "from langchain import PromptTemplate\n", + "from langchain import LLMChain\n", + "TEMPLATE = \"\"\"You are a data engineer answering questions using a SQL database.\n", + "\n", + "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. If a previous query produced an error, do NOT try it again.\n", + "\n", + "Only use the following tables:\n", + "\n", + "{table_info}\n", + "\n", + "Use the following format:\n", + "\n", + "Question: the input question you must answer\n", + "Thought: you should always think about what to do\n", + "Action: SQLDB\n", + "Action Input: the query to run against the SQL database\n", + "Observation: the result of the action\n", + "... (this Thought/Action/Action Input/Observation can repeat N times)\n", + "Thought: I now know the final answer\n", + "Final Answer: the final answer to the original input question\n", + "\n", + "Begin!\n", + "\n", + "Question: {{input}}\"\"\".format(**{\n", + " \"dialect\": db.dialect,\n", + " \"table_info\": db.table_info})\n", + "prompt = PromptTemplate(template=TEMPLATE, input_variables=[\"input\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a6fb1b4f", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain = LLMChain(llm=llm, prompt=prompt)\n", + "tools = [\n", + " Tool(\"SQLDB\", db.run, \"foo\")\n", + "]\n", + "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "64578802", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What was our total number of sales and total revenue in 2013?\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find the total number of sales and total revenue\n", + "Action: SQLDB\n", + "Action Input: SELECT SUM(Total) AS \"Total Sales\", SUM(Total) AS \"Total Revenue\" FROM Invoice WHERE InvoiceDate LIKE '2013%'\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m[(450.58000000000027, 450.58000000000027)]\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 450.58\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'450.58'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"What was our total number of sales and total revenue in 2013?\")" + ] + }, { "cell_type": "code", "execution_count": 4, diff --git a/langchain/agents/sql_agent.py b/langchain/agents/sql_agent.py new file mode 100644 index 00000000..da9d5e57 --- /dev/null +++ b/langchain/agents/sql_agent.py @@ -0,0 +1,22 @@ +TEMPLATE = """You are a data engineer answering questions using a SQL database. + +Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. + +Only use the following tables: + +{table_info} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: SQLDB +Action Input: the query to run against the SQL database +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! + +Question: {input}""" \ No newline at end of file diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 32bb1c9c..0e231006 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -64,7 +64,10 @@ class SQLDatabaseChain(Chain, BaseModel): } sql_cmd = llm_chain.predict(**llm_inputs) chained_input.add(sql_cmd, color="green") - result = self.database.run(sql_cmd) + try: + result = self.database.run(sql_cmd) + except Exception as e: + result = str(e) chained_input.add("\nSQLResult: ") chained_input.add(result, color="yellow") chained_input.add("\nAnswer:") diff --git a/langchain/sql_database.py b/langchain/sql_database.py index a04ab15a..6463b36b 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -67,5 +67,8 @@ class SQLDatabase: def run(self, command: str) -> str: """Execute a SQL command and return a string of the results.""" - result = self._engine.execute(command).fetchall() + try: + result = self._engine.execute(command).fetchall() + except Exception as e: + result = e return str(result)