Improving Resilience of MRKL Agent (#5014)

This is a highly optimized update to the pull request
https://github.com/hwchase17/langchain/pull/3269

Summary:
1) Added ability to MRKL agent to self solve the ValueError(f"Could not
parse LLM output: `{llm_output}`") error, whenever llm (especially
gpt-3.5-turbo) does not follow the format of MRKL Agent, while returning
"Action:" & "Action Input:".
2) The way I am solving this error is by responding back to the llm with
the messages "Invalid Format: Missing 'Action:' after 'Thought:'" &
"Invalid Format: Missing 'Action Input:' after 'Action:'" whenever
Action: and Action Input: are not present in the llm output
respectively.

For a detailed explanation, look at the previous pull request.

New Updates:
1) Since @hwchase17 , requested in the previous PR to communicate the
self correction (error) message, using the OutputParserException, I have
added new ability to the OutputParserException class to store the
observation & previous llm_output in order to communicate it to the next
Agent's prompt. This is done, without breaking/modifying any of the
functionality OutputParserException previously performs (i.e.
OutputParserException can be used in the same way as before, without
passing any observation & previous llm_output too).

---------

Co-authored-by: Deepak S V <svdeepak99@users.noreply.github.com>
This commit is contained in:
Deepak S V 2023-05-22 14:08:08 -04:00 committed by GitHub
parent 6eacd88ae7
commit 5cd12102be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 6 deletions

View File

@ -773,6 +773,10 @@ class AgentExecutor(Chain):
raise e
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.handle_parsing_errors, str):
observation = self.handle_parsing_errors

View File

@ -23,6 +23,24 @@ class MRKLOutputParser(AgentOutputParser):
)
match = re.search(regex, text, re.DOTALL)
if not match:
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
raise OutputParserException(
f"Could not parse LLM output: `{text}`",
observation="Invalid Format: Missing 'Action:' after 'Thought:'",
llm_output=text,
send_to_llm=True,
)
elif not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
raise OutputParserException(
f"Could not parse LLM output: `{text}`",
observation="Invalid Format:"
" Missing 'Action Input:' after 'Action:'",
llm_output=text,
send_to_llm=True,
)
else:
raise OutputParserException(f"Could not parse LLM output: `{text}`")
action = match.group(1).strip()
action_input = match.group(2)

View File

@ -369,7 +369,23 @@ class OutputParserException(ValueError):
errors will be raised.
"""
pass
def __init__(
self,
error: Any,
observation: str | None = None,
llm_output: str | None = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
class BaseDocumentTransformer(ABC):

View File

@ -119,17 +119,19 @@ def test_get_final_answer_multiline() -> None:
def test_bad_action_input_line() -> None:
"""Test handling when no action input found."""
llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA"
with pytest.raises(OutputParserException):
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None
def test_bad_action_line() -> None:
"""Test handling when no action input found."""
"""Test handling when no action found."""
llm_output = (
"Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA"
)
with pytest.raises(OutputParserException):
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None
def test_from_chains() -> None: