mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
|
"""Test MRKL functionality."""
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
|
||
|
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||
|
from langchain.prompt import Prompt
|
||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||
|
|
||
|
|
||
|
def test_get_action_and_input() -> None:
    """Check that an action and its input are extracted from LLM text."""
    # Three-line transcript: a thought, the chosen action, and the action input.
    text = (
        "Thought: I need to search for NBA\n"
        "Action: Search\n"
        "Action Input: NBA"
    )
    parsed_action, parsed_input = get_action_and_input(text)
    assert parsed_action == "Search"
    assert parsed_input == "NBA"
|
||
|
|
||
|
|
||
|
def test_get_final_answer() -> None:
    """Check parsing of a transcript that ends with a ``Final Answer:`` line."""
    # A full round trip (thought / action / observation) followed by the answer.
    transcript_parts = [
        "Thought: I need to search for NBA\n",
        "Action: Search\n",
        "Action Input: NBA\n",
        "Observation: founded in 1994\n",
        "Thought: I can now answer the question\n",
        "Final Answer: 1994",
    ]
    parsed_action, parsed_input = get_action_and_input("".join(transcript_parts))
    # The expected action is the literal marker, including its trailing space.
    assert parsed_action == "Final Answer: "
    assert parsed_input == "1994"
|
||
|
|
||
|
|
||
|
def test_bad_action_input_line() -> None:
    """Check that a missing ``Action Input:`` line raises ``ValueError``."""
    # The last line is a second "Thought:" where "Action Input:" would be.
    malformed_text = (
        "Thought: I need to search for NBA\n"
        "Action: Search\n"
        "Thought: NBA"
    )
    with pytest.raises(ValueError):
        get_action_and_input(malformed_text)
|
||
|
|
||
|
|
||
|
def test_bad_action_line() -> None:
    """Check that a missing ``Action:`` line raises ``ValueError``."""
    # The middle line is a second "Thought:" where "Action:" would be.
    malformed_text = (
        "Thought: I need to search for NBA\n"
        "Thought: Search\n"
        "Action Input: NBA"
    )
    with pytest.raises(ValueError):
        get_action_and_input(malformed_text)
|
||
|
|
||
|
|
||
|
def test_from_chains() -> None:
    """Check the prompt that ``MRKLChain.from_chains`` builds from chain configs."""
    foo_config = ChainConfig(
        action_name="foo", action=lambda x: "foo", action_description="foobar1"
    )
    bar_config = ChainConfig(
        action_name="bar", action=lambda x: "bar", action_description="foobar2"
    )
    chain = MRKLChain.from_chains(FakeLLM(), [foo_config, bar_config])

    # Expected template: BASE_TEMPLATE filled with one "name: description"
    # line per tool and a comma-separated list of the tool names.
    wanted_template = BASE_TEMPLATE.format(
        tools="foo: foobar1\nbar: foobar2", tool_names="foo, bar"
    )

    built_prompt = chain.prompt
    assert isinstance(built_prompt, Prompt)
    assert built_prompt.template == wanted_template
|