diff --git a/langchain/agents/mrkl/output_parser.py b/langchain/agents/mrkl/output_parser.py index e81baef3..3ef3ed0e 100644 --- a/langchain/agents/mrkl/output_parser.py +++ b/langchain/agents/mrkl/output_parser.py @@ -44,7 +44,13 @@ class MRKLOutputParser(AgentOutputParser): raise OutputParserException(f"Could not parse LLM output: `{text}`") action = match.group(1).strip() action_input = match.group(2) - return AgentAction(action, action_input.strip(" ").strip('"'), text) + + tool_input = action_input.strip(" ") + # ensure if its a well formed SQL query we don't remove any trailing " chars + if tool_input.startswith("SELECT ") is False: + tool_input = tool_input.strip('"') + + return AgentAction(action, tool_input, text) @property def _type(self) -> str: diff --git a/tests/unit_tests/agents/test_mrkl.py b/tests/unit_tests/agents/test_mrkl.py index ffaf55f5..1b8802db 100644 --- a/tests/unit_tests/agents/test_mrkl.py +++ b/tests/unit_tests/agents/test_mrkl.py @@ -71,6 +71,23 @@ def test_get_action_and_input_newline_after_keyword() -> None: assert action_input == "ls -l ~/.bashrc.d/\n" +def test_get_action_and_input_sql_query() -> None: + """Test getting the action and action input from the text + when the LLM output is a well formed SQL query + """ + llm_output = """ + I should query for the largest single shift payment for every unique user. + Action: query_sql_db + Action Input: \ + SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName" """ + action, action_input = get_action_and_input(llm_output) + assert action == "query_sql_db" + assert ( + action_input + == 'SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName"' + ) + + def test_get_final_answer() -> None: """Test getting final answer.""" llm_output = (