From 2456a547de12f5777397b49c88d4418fb3f62a59 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 5 Nov 2022 14:41:53 -0700 Subject: [PATCH] mrkl (#42) --- examples/mrkl.ipynb | 171 ++++++++++++++++++++++++++ langchain/__init__.py | 2 + langchain/chains/__init__.py | 2 + langchain/chains/mrkl/__init__.py | 1 + langchain/chains/mrkl/base.py | 176 +++++++++++++++++++++++++++ langchain/chains/mrkl/prompt.py | 19 +++ langchain/input.py | 13 +- tests/unit_tests/chains/test_mrkl.py | 70 +++++++++++ tests/unit_tests/test_input.py | 24 +++- 9 files changed, 476 insertions(+), 2 deletions(-) create mode 100644 examples/mrkl.ipynb create mode 100644 langchain/chains/mrkl/__init__.py create mode 100644 langchain/chains/mrkl/base.py create mode 100644 langchain/chains/mrkl/prompt.py create mode 100644 tests/unit_tests/chains/test_mrkl.py diff --git a/examples/mrkl.ipynb b/examples/mrkl.ipynb new file mode 100644 index 0000000000..3e4b40e2ab --- /dev/null +++ b/examples/mrkl.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ac561cc4", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n", + "from langchain.chains.mrkl.base import ChainConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "07e96d99", + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI(temperature=0)\n", + "search = SerpAPIChain()\n", + "llm_math_chain = LLMMathChain(llm=llm)\n", + "db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n", + "db_chain = SQLDatabaseChain(llm=llm, database=db)\n", + "chains = [\n", + " ChainConfig(\n", + " action_name = \"Search\",\n", + " action=search.search,\n", + " action_description=\"useful for when you need to answer questions about current events\"\n", + " ),\n", + " ChainConfig(\n", + " action_name=\"Calculator\",\n", + " action=llm_math_chain.run,\n", + " action_description=\"useful for when you need to answer questions about math\"\n", + " ),\n", + " \n", + " ChainConfig(\n", + " action_name=\"FooBar DB\",\n", + " action=db_chain.query,\n", + " action_description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a069c4b6", + "metadata": {}, + "outputs": [], + "source": [ + "mrkl = MRKLChain.from_chains(llm, chains, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e603cd7d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n", + "Thought:\u001b[102m I need to find the age of Olivia Wilde's boyfriend\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n", + "Observation: \u001b[104mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", + "Thought:\u001b[102m I need to find the age of Harry Styles\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[104m28 years\u001b[0m\n", + "Thought:\u001b[102m I need to calculate 28 to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 28^0.23\u001b[0m\n", + "Observation: \u001b[103mAnswer: 2.1520202182226886\n", + "\u001b[0m\n", + "Thought:\u001b[102m I now know the final answer\n", + "Final Answer: 2.1520202182226886\u001b[0m" + ] + }, + { + "data": { + "text/plain": [ + "'2.1520202182226886'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrkl.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a5c07010", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n", + "Thought:\u001b[102m I need to find an album called 'The Storm Before the Calm'\n", + "Action: Search\n", + "Action Input: \"The Storm Before the Calm album\"\u001b[0m\n", + "Observation: \u001b[104mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n", + "Thought:\u001b[102m I need to check if Alanis is in the FooBar database\n", + "Action: FooBar DB\n", + "Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n", + "Observation: \u001b[101m Yes\u001b[0m\n", + "Thought:\u001b[102m I need to find out what albums of Alanis's are in the FooBar database\n", + "Action: FooBar DB\n", + "Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n", + "Observation: \u001b[101m The album \"Jagged Little Pill\" by Alanis Morissette is in the FooBar database.\u001b[0m\n", + "Thought:\u001b[102m I now know the final answer\n", + "Final Answer: The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.\u001b[0m" + ] + }, + { + "data": { + "text/plain": [ + "'The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrkl.run(\"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7c2e6ac", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index 439cba0804..66e4fe03c7 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -8,6 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f: from langchain.chains import ( LLMChain, LLMMathChain, + MRKLChain, PythonChain, ReActChain, SelfAskWithSearchChain, @@ -37,4 +38,5 @@ __all__ = [ "SQLDatabase", "SQLDatabaseChain", "FAISS", + "MRKLChain", ] diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index adde1ed8cc..45fb20cb95 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -1,6 +1,7 @@ """Chains are easily reusable components which can be linked together.""" from langchain.chains.llm import LLMChain from langchain.chains.llm_math.base import LLMMathChain +from langchain.chains.mrkl.base import MRKLChain from langchain.chains.python import PythonChain from langchain.chains.react.base import ReActChain from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain @@ -15,4 +16,5 @@ __all__ = [ "SerpAPIChain", "ReActChain", "SQLDatabaseChain", + "MRKLChain", ] diff --git a/langchain/chains/mrkl/__init__.py b/langchain/chains/mrkl/__init__.py new file mode 100644 index 0000000000..a86a5b510d --- /dev/null +++ b/langchain/chains/mrkl/__init__.py @@ -0,0 +1 @@ +"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf.""" diff --git a/langchain/chains/mrkl/base.py b/langchain/chains/mrkl/base.py new file mode 100644 index 0000000000..80b4e5c0e5 --- /dev/null +++ b/langchain/chains/mrkl/base.py @@ -0,0 +1,176 @@ +"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf.""" +from typing import Any, Callable, Dict, List, NamedTuple, Tuple + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.mrkl.prompt import BASE_TEMPLATE +from langchain.input import ChainedInput, get_color_mapping +from langchain.llms.base import LLM +from langchain.prompt import BasePrompt, Prompt + +FINAL_ANSWER_ACTION = "Final Answer: " + + +class ChainConfig(NamedTuple): + """Configuration for chain to use in MRKL system. + + Args: + action_name: Name of the action. + action: Action function to call. + action_description: Description of the action. + """ + + action_name: str + action: Callable + action_description: str + + +def get_action_and_input(llm_output: str) -> Tuple[str, str]: + """Parse out the action and input from the LLM output.""" + ps = [p for p in llm_output.split("\n") if p] + if ps[-1].startswith(FINAL_ANSWER_ACTION): + directive = ps[-1][len(FINAL_ANSWER_ACTION) :] + return FINAL_ANSWER_ACTION, directive + if not ps[-1].startswith("Action Input: "): + raise ValueError( + "The last line does not have an action input, " + "something has gone terribly wrong." + ) + if not ps[-2].startswith("Action: "): + raise ValueError( + "The second to last line does not have an action, " + "something has gone terribly wrong." + ) + action = ps[-2][len("Action: ") :] + action_input = ps[-1][len("Action Input: ") :] + return action, action_input.strip(" ").strip('"') + + +class MRKLChain(Chain, BaseModel): + """Chain that implements the MRKL system. + + Example: + .. code-block:: python + + from langchain import OpenAI, Prompt, MRKLChain + from langchain.chains.mrkl.base import ChainConfig + llm = OpenAI(temperature=0) + prompt = Prompt(...) + action_to_chain_map = {...} + mrkl = MRKLChain( + llm=llm, + prompt=prompt, + action_to_chain_map=action_to_chain_map + ) + """ + + llm: LLM + """LLM wrapper to use as router.""" + prompt: BasePrompt + """Prompt to use as router.""" + action_to_chain_map: Dict[str, Callable] + """Mapping from action name to chain to execute.""" + verbose: bool = False + """Whether to print out the code that was executed.""" + input_key: str = "question" #: :meta private: + output_key: str = "answer" #: :meta private: + + @classmethod + def from_chains( + cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any + ) -> "MRKLChain": + """User friendly way to initialize the MRKL chain. + + This is intended to be an easy way to get up and running with the + MRKL chain. + + Args: + llm: The LLM to use as the router LLM. + chains: The chains the MRKL system has access to. + **kwargs: parameters to be passed to initialization. + + Returns: + An initialized MRKL chain. + + Example: + .. code-block:: python + + from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain + from langchain.chains.mrkl.base import ChainConfig + llm = OpenAI(temperature=0) + search = SerpAPIChain() + llm_math_chain = LLMMathChain(llm=llm) + chains = [ + ChainConfig( + action_name = "Search", + action=search.search, + action_description="useful for searching" + ), + ChainConfig( + action_name="Calculator", + action=llm_math_chain.run, + action_description="useful for doing math" + ) + ] + mrkl = MRKLChain.from_chains(llm, chains) + """ + tools = "\n".join( + [f"{chain.action_name}: {chain.action_description}" for chain in chains] + ) + tool_names = ", ".join([chain.action_name for chain in chains]) + template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names) + prompt = Prompt(template=template, input_variables=["input"]) + action_to_chain_map = {chain.action_name: chain.action for chain in chains} + return cls( + llm=llm, prompt=prompt, action_to_chain_map=action_to_chain_map, **kwargs + ) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key. + + :meta private: + """ + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) + chained_input = ChainedInput( + f"{inputs[self.input_key]}\nThought:", verbose=self.verbose + ) + color_mapping = get_color_mapping( + list(self.action_to_chain_map.keys()), excluded_colors=["green"] + ) + while True: + thought = llm_chain.predict( + input=chained_input.input, stop=["\nObservation"] + ) + chained_input.add(thought, color="green") + action, action_input = get_action_and_input(thought) + if action == FINAL_ANSWER_ACTION: + return {self.output_key: action_input} + chain = self.action_to_chain_map[action] + ca = chain(action_input) + chained_input.add("\nObservation: ") + chained_input.add(ca, color=color_mapping[action]) + chained_input.add("\nThought:") + + def run(self, _input: str) -> str: + """Run input through the MRKL system.""" + return self({self.input_key: _input})[self.output_key] diff --git a/langchain/chains/mrkl/prompt.py b/langchain/chains/mrkl/prompt.py new file mode 100644 index 0000000000..128bdb8552 --- /dev/null +++ b/langchain/chains/mrkl/prompt.py @@ -0,0 +1,19 @@ +# flake8: noqa +BASE_TEMPLATE = """Answer the following questions as best you can. You have access to the following tools: + +{tools} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! + +Question: {{input}}""" diff --git a/langchain/input.py b/langchain/input.py index d4ac80e72c..0d1b3b6daf 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -1,9 +1,20 @@ """Handle chained inputs.""" -from typing import Optional +from typing import Dict, List, Optional _COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102} +def get_color_mapping( + items: List[str], excluded_colors: Optional[List] = None +) -> Dict[str, str]: + """Get mapping for items to a support color.""" + colors = list(_COLOR_MAPPING.keys()) + if excluded_colors is not None: + colors = [c for c in colors if c not in excluded_colors] + color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} + return color_mapping + + def print_text(text: str, color: Optional[str] = None) -> None: """Print text with highlighting and no end characters.""" if color is None: diff --git a/tests/unit_tests/chains/test_mrkl.py b/tests/unit_tests/chains/test_mrkl.py new file mode 100644 index 0000000000..500181bf85 --- /dev/null +++ b/tests/unit_tests/chains/test_mrkl.py @@ -0,0 +1,70 @@ +"""Test MRKL functionality.""" + +import pytest + +from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input +from langchain.chains.mrkl.prompt import BASE_TEMPLATE +from langchain.prompt import Prompt +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def test_get_action_and_input() -> None: + """Test getting an action from text.""" + llm_output = ( + "Thought: I need to search for NBA\n" "Action: Search\n" "Action Input: NBA" + ) + action, action_input = get_action_and_input(llm_output) + assert action == "Search" + assert action_input == "NBA" + + +def test_get_final_answer() -> None: + """Test getting final answer.""" + llm_output = ( + "Thought: I need to search for NBA\n" + "Action: Search\n" + "Action Input: NBA\n" + "Observation: founded in 1994\n" + "Thought: I can now answer the question\n" + "Final Answer: 1994" + ) + action, action_input = get_action_and_input(llm_output) + assert action == "Final Answer: " + assert action_input == "1994" + + +def test_bad_action_input_line() -> None: + """Test handling when no action input found.""" + llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA" + with pytest.raises(ValueError): + get_action_and_input(llm_output) + + +def test_bad_action_line() -> None: + """Test handling when no action input found.""" + llm_output = ( + "Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA" + ) + with pytest.raises(ValueError): + get_action_and_input(llm_output) + + +def test_from_chains() -> None: + """Test initializing from chains.""" + chain_configs = [ + ChainConfig( + action_name="foo", action=lambda x: "foo", action_description="foobar1" + ), + ChainConfig( + action_name="bar", action=lambda x: "bar", action_description="foobar2" + ), + ] + mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs) + expected_tools_prompt = "foo: foobar1\nbar: foobar2" + expected_tool_names = "foo, bar" + expected_template = BASE_TEMPLATE.format( + tools=expected_tools_prompt, tool_names=expected_tool_names + ) + prompt = mrkl_chain.prompt + assert isinstance(prompt, Prompt) + assert prompt.template == expected_template diff --git a/tests/unit_tests/test_input.py b/tests/unit_tests/test_input.py index b2d1a52713..dd17bfc5de 100644 --- a/tests/unit_tests/test_input.py +++ b/tests/unit_tests/test_input.py @@ -3,7 +3,7 @@ import sys from io import StringIO -from langchain.input import ChainedInput +from langchain.input import ChainedInput, get_color_mapping def test_chained_input_not_verbose() -> None: @@ -50,3 +50,25 @@ def test_chained_input_verbose() -> None: output = mystdout.getvalue() assert output == "\x1b[104mbaz\x1b[0m" assert chained_input.input == "foobarbaz" + + +def test_get_color_mapping() -> None: + """Test getting of color mapping.""" + # Test on few inputs. + items = ["foo", "bar"] + output = get_color_mapping(items) + expected_output = {"foo": "blue", "bar": "yellow"} + assert output == expected_output + + # Test on a lot of inputs. + items = [f"foo-{i}" for i in range(20)] + output = get_color_mapping(items) + assert len(output) == 20 + + +def test_get_color_mapping_excluded_colors() -> None: + """Test getting of color mapping with excluded colors.""" + items = ["foo", "bar"] + output = get_color_mapping(items, excluded_colors=["blue"]) + expected_output = {"foo": "yellow", "bar": "red"} + assert output == expected_output