This commit is contained in:
Harrison Chase 2022-11-05 14:41:53 -07:00 committed by GitHub
parent c636488fe5
commit 2456a547de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 476 additions and 2 deletions

171
examples/mrkl.ipynb Normal file
View File

@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ac561cc4",
"metadata": {},
"outputs": [],
"source": [
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n",
"from langchain.chains.mrkl.base import ChainConfig"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "07e96d99",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"llm_math_chain = LLMMathChain(llm=llm)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db)\n",
"chains = [\n",
" ChainConfig(\n",
" action_name = \"Search\",\n",
" action=search.search,\n",
" action_description=\"useful for when you need to answer questions about current events\"\n",
" ),\n",
" ChainConfig(\n",
" action_name=\"Calculator\",\n",
" action=llm_math_chain.run,\n",
" action_description=\"useful for when you need to answer questions about math\"\n",
" ),\n",
" \n",
" ChainConfig(\n",
" action_name=\"FooBar DB\",\n",
" action=db_chain.query,\n",
" action_description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question\"\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a069c4b6",
"metadata": {},
"outputs": [],
"source": [
"mrkl = MRKLChain.from_chains(llm, chains, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e603cd7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n",
"Thought:\u001b[102m I need to find the age of Olivia Wilde's boyfriend\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n",
"Observation: \u001b[104mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
"Thought:\u001b[102m I need to find the age of Harry Styles\n",
"Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[104m28 years\u001b[0m\n",
"Thought:\u001b[102m I need to calculate 28 to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 28^0.23\u001b[0m\n",
"Observation: \u001b[103mAnswer: 2.1520202182226886\n",
"\u001b[0m\n",
"Thought:\u001b[102m I now know the final answer\n",
"Final Answer: 2.1520202182226886\u001b[0m"
]
},
{
"data": {
"text/plain": [
"'2.1520202182226886'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a5c07010",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n",
"Thought:\u001b[102m I need to find an album called 'The Storm Before the Calm'\n",
"Action: Search\n",
"Action Input: \"The Storm Before the Calm album\"\u001b[0m\n",
"Observation: \u001b[104mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n",
"Thought:\u001b[102m I need to check if Alanis is in the FooBar database\n",
"Action: FooBar DB\n",
"Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n",
"Observation: \u001b[101m Yes\u001b[0m\n",
"Thought:\u001b[102m I need to find out what albums of Alanis's are in the FooBar database\n",
"Action: FooBar DB\n",
"Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n",
"Observation: \u001b[101m The album \"Jagged Little Pill\" by Alanis Morissette is in the FooBar database.\u001b[0m\n",
"Thought:\u001b[102m I now know the final answer\n",
"Final Answer: The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.\u001b[0m"
]
},
{
"data": {
"text/plain": [
"'The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7c2e6ac",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -8,6 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
from langchain.chains import (
LLMChain,
LLMMathChain,
MRKLChain,
PythonChain,
ReActChain,
SelfAskWithSearchChain,
@ -37,4 +38,5 @@ __all__ = [
"SQLDatabase",
"SQLDatabaseChain",
"FAISS",
"MRKLChain",
]

View File

@ -1,6 +1,7 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.mrkl.base import MRKLChain
from langchain.chains.python import PythonChain
from langchain.chains.react.base import ReActChain
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
@ -15,4 +16,5 @@ __all__ = [
"SerpAPIChain",
"ReActChain",
"SQLDatabaseChain",
"MRKLChain",
]

View File

@ -0,0 +1 @@
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""

View File

@ -0,0 +1,176 @@
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from typing import Any, Callable, Dict, List, NamedTuple, Tuple
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.input import ChainedInput, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompt import BasePrompt, Prompt
FINAL_ANSWER_ACTION = "Final Answer: "
class ChainConfig(NamedTuple):
"""Configuration for chain to use in MRKL system.
Args:
action_name: Name of the action.
action: Action function to call.
action_description: Description of the action.
"""
action_name: str
action: Callable
action_description: str
def get_action_and_input(llm_output: str) -> Tuple[str, str]:
"""Parse out the action and input from the LLM output."""
ps = [p for p in llm_output.split("\n") if p]
if ps[-1].startswith(FINAL_ANSWER_ACTION):
directive = ps[-1][len(FINAL_ANSWER_ACTION) :]
return FINAL_ANSWER_ACTION, directive
if not ps[-1].startswith("Action Input: "):
raise ValueError(
"The last line does not have an action input, "
"something has gone terribly wrong."
)
if not ps[-2].startswith("Action: "):
raise ValueError(
"The second to last line does not have an action, "
"something has gone terribly wrong."
)
action = ps[-2][len("Action: ") :]
action_input = ps[-1][len("Action Input: ") :]
return action, action_input.strip(" ").strip('"')
class MRKLChain(Chain, BaseModel):
"""Chain that implements the MRKL system.
Example:
.. code-block:: python
from langchain import OpenAI, Prompt, MRKLChain
from langchain.chains.mrkl.base import ChainConfig
llm = OpenAI(temperature=0)
prompt = Prompt(...)
action_to_chain_map = {...}
mrkl = MRKLChain(
llm=llm,
prompt=prompt,
action_to_chain_map=action_to_chain_map
)
"""
llm: LLM
"""LLM wrapper to use as router."""
prompt: BasePrompt
"""Prompt to use as router."""
action_to_chain_map: Dict[str, Callable]
"""Mapping from action name to chain to execute."""
verbose: bool = False
"""Whether to print out the code that was executed."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
@classmethod
def from_chains(
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
) -> "MRKLChain":
"""User friendly way to initialize the MRKL chain.
This is intended to be an easy way to get up and running with the
MRKL chain.
Args:
llm: The LLM to use as the router LLM.
chains: The chains the MRKL system has access to.
**kwargs: parameters to be passed to initialization.
Returns:
An initialized MRKL chain.
Example:
.. code-block:: python
from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
from langchain.chains.mrkl.base import ChainConfig
llm = OpenAI(temperature=0)
search = SerpAPIChain()
llm_math_chain = LLMMathChain(llm=llm)
chains = [
ChainConfig(
action_name = "Search",
action=search.search,
action_description="useful for searching"
),
ChainConfig(
action_name="Calculator",
action=llm_math_chain.run,
action_description="useful for doing math"
)
]
mrkl = MRKLChain.from_chains(llm, chains)
"""
tools = "\n".join(
[f"{chain.action_name}: {chain.action_description}" for chain in chains]
)
tool_names = ", ".join([chain.action_name for chain in chains])
template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names)
prompt = Prompt(template=template, input_variables=["input"])
action_to_chain_map = {chain.action_name: chain.action for chain in chains}
return cls(
llm=llm, prompt=prompt, action_to_chain_map=action_to_chain_map, **kwargs
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
chained_input = ChainedInput(
f"{inputs[self.input_key]}\nThought:", verbose=self.verbose
)
color_mapping = get_color_mapping(
list(self.action_to_chain_map.keys()), excluded_colors=["green"]
)
while True:
thought = llm_chain.predict(
input=chained_input.input, stop=["\nObservation"]
)
chained_input.add(thought, color="green")
action, action_input = get_action_and_input(thought)
if action == FINAL_ANSWER_ACTION:
return {self.output_key: action_input}
chain = self.action_to_chain_map[action]
ca = chain(action_input)
chained_input.add("\nObservation: ")
chained_input.add(ca, color=color_mapping[action])
chained_input.add("\nThought:")
def run(self, _input: str) -> str:
"""Run input through the MRKL system."""
return self({self.input_key: _input})[self.output_key]

View File

@ -0,0 +1,19 @@
# flake8: noqa
BASE_TEMPLATE = """Answer the following questions as best you can. You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {{input}}"""

View File

@ -1,9 +1,20 @@
"""Handle chained inputs."""
from typing import Optional
from typing import Dict, List, Optional
_COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def print_text(text: str, color: Optional[str] = None) -> None:
"""Print text with highlighting and no end characters."""
if color is None:

View File

@ -0,0 +1,70 @@
"""Test MRKL functionality."""
import pytest
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.prompt import Prompt
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_get_action_and_input() -> None:
"""Test getting an action from text."""
llm_output = (
"Thought: I need to search for NBA\n" "Action: Search\n" "Action Input: NBA"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Search"
assert action_input == "NBA"
def test_get_final_answer() -> None:
"""Test getting final answer."""
llm_output = (
"Thought: I need to search for NBA\n"
"Action: Search\n"
"Action Input: NBA\n"
"Observation: founded in 1994\n"
"Thought: I can now answer the question\n"
"Final Answer: 1994"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer: "
assert action_input == "1994"
def test_bad_action_input_line() -> None:
"""Test handling when no action input found."""
llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA"
with pytest.raises(ValueError):
get_action_and_input(llm_output)
def test_bad_action_line() -> None:
"""Test handling when no action input found."""
llm_output = (
"Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA"
)
with pytest.raises(ValueError):
get_action_and_input(llm_output)
def test_from_chains() -> None:
"""Test initializing from chains."""
chain_configs = [
ChainConfig(
action_name="foo", action=lambda x: "foo", action_description="foobar1"
),
ChainConfig(
action_name="bar", action=lambda x: "bar", action_description="foobar2"
),
]
mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar"
expected_template = BASE_TEMPLATE.format(
tools=expected_tools_prompt, tool_names=expected_tool_names
)
prompt = mrkl_chain.prompt
assert isinstance(prompt, Prompt)
assert prompt.template == expected_template

View File

@ -3,7 +3,7 @@
import sys
from io import StringIO
from langchain.input import ChainedInput
from langchain.input import ChainedInput, get_color_mapping
def test_chained_input_not_verbose() -> None:
@ -50,3 +50,25 @@ def test_chained_input_verbose() -> None:
output = mystdout.getvalue()
assert output == "\x1b[104mbaz\x1b[0m"
assert chained_input.input == "foobarbaz"
def test_get_color_mapping() -> None:
"""Test getting of color mapping."""
# Test on few inputs.
items = ["foo", "bar"]
output = get_color_mapping(items)
expected_output = {"foo": "blue", "bar": "yellow"}
assert output == expected_output
# Test on a lot of inputs.
items = [f"foo-{i}" for i in range(20)]
output = get_color_mapping(items)
assert len(output) == 20
def test_get_color_mapping_excluded_colors() -> None:
"""Test getting of color mapping with excluded colors."""
items = ["foo", "bar"]
output = get_color_mapping(items, excluded_colors=["blue"])
expected_output = {"foo": "yellow", "bar": "red"}
assert output == expected_output