langchain/tests/unit_tests/agents/test_mrkl.py
Matt Wells 1d861dc37a
MRKL output parser no longer breaks well formed queries (#5432)
# Handles the edge scenario in which the action input is a well formed
SQL query which ends with a quoted column

There may be a cleaner option here (or indeed other edge scenarios) but
this seems to robustly determine if the action input is likely to be a
well formed SQL query in which we don't want to arbitrarily trim off `"`
characters

Fixes #5423

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Agents / Tools / Toolkits
  - @vowelparrot
2023-05-30 15:58:47 -04:00

174 lines
5.8 KiB
Python

"""Test MRKL functionality."""
from typing import Tuple
import pytest
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.prompts import PromptTemplate
from langchain.schema import AgentAction, OutputParserException
from tests.unit_tests.llms.fake_llm import FakeLLM
def get_action_and_input(text: str) -> Tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]
def test_get_action_and_input() -> None:
"""Test getting an action from text."""
llm_output = (
"Thought: I need to search for NBA\n" "Action: Search\n" "Action Input: NBA"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Search"
assert action_input == "NBA"
def test_get_action_and_input_whitespace() -> None:
"""Test getting an action from text."""
llm_output = "Thought: I need to search for NBA\nAction: Search \nAction Input: NBA"
action, action_input = get_action_and_input(llm_output)
assert action == "Search"
assert action_input == "NBA"
def test_get_action_and_input_newline() -> None:
"""Test getting an action from text where Action Input is a code snippet."""
llm_output = (
"Now I need to write a unittest for the function.\n\n"
"Action: Python\nAction Input:\n```\nimport unittest\n\nunittest.main()\n```"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Python"
assert action_input == "```\nimport unittest\n\nunittest.main()\n```"
def test_get_action_and_input_newline_after_keyword() -> None:
"""Test getting an action and action input from the text
when there is a new line before the action
(after the keywords "Action:" and "Action Input:")
"""
llm_output = """
I can use the `ls` command to list the contents of the directory \
and `grep` to search for the specific file.
Action:
Terminal
Action Input:
ls -l ~/.bashrc.d/
"""
action, action_input = get_action_and_input(llm_output)
assert action == "Terminal"
assert action_input == "ls -l ~/.bashrc.d/\n"
def test_get_action_and_input_sql_query() -> None:
"""Test getting the action and action input from the text
when the LLM output is a well formed SQL query
"""
llm_output = """
I should query for the largest single shift payment for every unique user.
Action: query_sql_db
Action Input: \
SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName" """
action, action_input = get_action_and_input(llm_output)
assert action == "query_sql_db"
assert (
action_input
== 'SELECT "UserName", MAX(totalpayment) FROM user_shifts GROUP BY "UserName"'
)
def test_get_final_answer() -> None:
"""Test getting final answer."""
llm_output = (
"Thought: I need to search for NBA\n"
"Action: Search\n"
"Action Input: NBA\n"
"Observation: founded in 1994\n"
"Thought: I can now answer the question\n"
"Final Answer: 1994"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994"
def test_get_final_answer_new_line() -> None:
"""Test getting final answer."""
llm_output = (
"Thought: I need to search for NBA\n"
"Action: Search\n"
"Action Input: NBA\n"
"Observation: founded in 1994\n"
"Thought: I can now answer the question\n"
"Final Answer:\n1994"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994"
def test_get_final_answer_multiline() -> None:
"""Test getting final answer that is multiline."""
llm_output = (
"Thought: I need to search for NBA\n"
"Action: Search\n"
"Action Input: NBA\n"
"Observation: founded in 1994 and 1993\n"
"Thought: I can now answer the question\n"
"Final Answer: 1994\n1993"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994\n1993"
def test_bad_action_input_line() -> None:
"""Test handling when no action input found."""
llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA"
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None
def test_bad_action_line() -> None:
"""Test handling when no action found."""
llm_output = (
"Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA"
)
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None
def test_from_chains() -> None:
"""Test initializing from chains."""
chain_configs = [
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
]
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar"
expected_template = "\n\n".join(
[
PREFIX,
expected_tools_prompt,
FORMAT_INSTRUCTIONS.format(tool_names=expected_tool_names),
SUFFIX,
]
)
prompt = agent.llm_chain.prompt
assert isinstance(prompt, PromptTemplate)
assert prompt.template == expected_template