MRKL output parser no longer breaks well formed queries (#5432)

# Handles the edge scenario in which the action input is a well formed
SQL query which ends with a quoted column

There may be a cleaner option here (or indeed other edge scenarios) but
this seems to robustly determine if the action input is likely to be a
well formed SQL query in which we don't want to arbitrarily trim off `"`
characters

Fixes #5423

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Agents / Tools / Toolkits
  - @vowelparrot
This commit is contained in:
Matt Wells 2023-05-30 20:58:47 +01:00 committed by GitHub
parent c1807d8408
commit 1d861dc37a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 1 deletions

View File

@ -44,7 +44,13 @@ class MRKLOutputParser(AgentOutputParser):
raise OutputParserException(f"Could not parse LLM output: `{text}`")
action = match.group(1).strip()
action_input = match.group(2)
return AgentAction(action, action_input.strip(" ").strip('"'), text)
tool_input = action_input.strip(" ")
# ensure if its a well formed SQL query we don't remove any trailing " chars
if tool_input.startswith("SELECT ") is False:
tool_input = tool_input.strip('"')
return AgentAction(action, tool_input, text)
@property
def _type(self) -> str:

View File

@ -71,6 +71,23 @@ def test_get_action_and_input_newline_after_keyword() -> None:
assert action_input == "ls -l ~/.bashrc.d/\n"
def test_get_action_and_input_sql_query() -> None:
"""Test getting the action and action input from the text
when the LLM output is a well formed SQL query
"""
llm_output = """
I should query for the largest single shift payment for every unique user.
Action: query_sql_db
Action Input: \
SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName" """
action, action_input = get_action_and_input(llm_output)
assert action == "query_sql_db"
assert (
action_input
== 'SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName"'
)
def test_get_final_answer() -> None:
"""Test getting final answer."""
llm_output = (