mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
mrkl (#42)
This commit is contained in:
parent
c636488fe5
commit
2456a547de
171
examples/mrkl.ipynb
Normal file
171
examples/mrkl.ipynb
Normal file
@ -0,0 +1,171 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "ac561cc4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n",
|
||||
"from langchain.chains.mrkl.base import ChainConfig"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "07e96d99",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"search = SerpAPIChain()\n",
|
||||
"llm_math_chain = LLMMathChain(llm=llm)\n",
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n",
|
||||
"db_chain = SQLDatabaseChain(llm=llm, database=db)\n",
|
||||
"chains = [\n",
|
||||
" ChainConfig(\n",
|
||||
" action_name = \"Search\",\n",
|
||||
" action=search.search,\n",
|
||||
" action_description=\"useful for when you need to answer questions about current events\"\n",
|
||||
" ),\n",
|
||||
" ChainConfig(\n",
|
||||
" action_name=\"Calculator\",\n",
|
||||
" action=llm_math_chain.run,\n",
|
||||
" action_description=\"useful for when you need to answer questions about math\"\n",
|
||||
" ),\n",
|
||||
" \n",
|
||||
" ChainConfig(\n",
|
||||
" action_name=\"FooBar DB\",\n",
|
||||
" action=db_chain.query,\n",
|
||||
" action_description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question\"\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a069c4b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mrkl = MRKLChain.from_chains(llm, chains, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e603cd7d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n",
|
||||
"Thought:\u001b[102m I need to find the age of Olivia Wilde's boyfriend\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[104mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
|
||||
"Thought:\u001b[102m I need to find the age of Harry Styles\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Harry Styles age\"\u001b[0m\n",
|
||||
"Observation: \u001b[104m28 years\u001b[0m\n",
|
||||
"Thought:\u001b[102m I need to calculate 28 to the 0.23 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 28^0.23\u001b[0m\n",
|
||||
"Observation: \u001b[103mAnswer: 2.1520202182226886\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[102m I now know the final answer\n",
|
||||
"Final Answer: 2.1520202182226886\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'2.1520202182226886'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mrkl.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "a5c07010",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n",
|
||||
"Thought:\u001b[102m I need to find an album called 'The Storm Before the Calm'\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"The Storm Before the Calm album\"\u001b[0m\n",
|
||||
"Observation: \u001b[104mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n",
|
||||
"Thought:\u001b[102m I need to check if Alanis is in the FooBar database\n",
|
||||
"Action: FooBar DB\n",
|
||||
"Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n",
|
||||
"Observation: \u001b[101m Yes\u001b[0m\n",
|
||||
"Thought:\u001b[102m I need to find out what albums of Alanis's are in the FooBar database\n",
|
||||
"Action: FooBar DB\n",
|
||||
"Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n",
|
||||
"Observation: \u001b[101m The album \"Jagged Little Pill\" by Alanis Morissette is in the FooBar database.\u001b[0m\n",
|
||||
"Thought:\u001b[102m I now know the final answer\n",
|
||||
"Final Answer: The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The album \"Jagged Little Pill\" by Alanis Morissette is the only album by Alanis Morissette in the FooBar database.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mrkl.run(\"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d7c2e6ac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -8,6 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
|
||||
from langchain.chains import (
|
||||
LLMChain,
|
||||
LLMMathChain,
|
||||
MRKLChain,
|
||||
PythonChain,
|
||||
ReActChain,
|
||||
SelfAskWithSearchChain,
|
||||
@ -37,4 +38,5 @@ __all__ = [
|
||||
"SQLDatabase",
|
||||
"SQLDatabaseChain",
|
||||
"FAISS",
|
||||
"MRKLChain",
|
||||
]
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Chains are easily reusable components which can be linked together."""
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.chains.mrkl.base import MRKLChain
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.chains.react.base import ReActChain
|
||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
@ -15,4 +16,5 @@ __all__ = [
|
||||
"SerpAPIChain",
|
||||
"ReActChain",
|
||||
"SQLDatabaseChain",
|
||||
"MRKLChain",
|
||||
]
|
||||
|
1
langchain/chains/mrkl/__init__.py
Normal file
1
langchain/chains/mrkl/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
|
176
langchain/chains/mrkl/base.py
Normal file
176
langchain/chains/mrkl/base.py
Normal file
@ -0,0 +1,176 @@
|
||||
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompt import BasePrompt, Prompt
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer: "
|
||||
|
||||
|
||||
class ChainConfig(NamedTuple):
    """Describe a single tool that the MRKL router may dispatch to.

    Attributes:
        action_name: Label under which the tool is listed in the prompt and
            looked up when the LLM selects an action.
        action: Callable that is invoked with the action input.
        action_description: Text shown to the LLM next to the action name.
    """

    action_name: str
    action: Callable
    action_description: str
|
||||
|
||||
|
||||
def get_action_and_input(llm_output: str) -> Tuple[str, str]:
    """Parse the action and its input out of the router LLM's output.

    Only the trailing non-empty lines of ``llm_output`` are inspected: either
    the last line is a final answer, or the last two lines are an
    ``Action: ...`` line followed by an ``Action Input: ...`` line.

    Args:
        llm_output: Raw text generated by the LLM.

    Returns:
        ``(FINAL_ANSWER_ACTION, answer)`` when the last line starts with
        ``FINAL_ANSWER_ACTION``; otherwise ``(action, action_input)`` where
        the input has surrounding spaces and double quotes removed.

    Raises:
        ValueError: If the output has no non-empty lines, if the last line is
            not an action input, or if the line before it is missing or is
            not an action.
    """
    action_prefix = "Action: "
    input_prefix = "Action Input: "
    # Blank lines carry no information, so drop them before looking at the tail.
    lines = [line for line in llm_output.split("\n") if line]
    if not lines:
        # Previously this fell through to ``ps[-1]`` and raised IndexError.
        raise ValueError("The LLM output is empty, so no action could be parsed.")
    last_line = lines[-1]
    if last_line.startswith(FINAL_ANSWER_ACTION):
        return FINAL_ANSWER_ACTION, last_line[len(FINAL_ANSWER_ACTION) :]
    if not last_line.startswith(input_prefix):
        raise ValueError(
            "The last line does not have an action input, "
            "something has gone terribly wrong."
        )
    # Check the length first: a lone "Action Input" line has no predecessor,
    # and indexing ``[-2]`` on it used to raise IndexError instead of ValueError.
    if len(lines) < 2 or not lines[-2].startswith(action_prefix):
        raise ValueError(
            "The second to last line does not have an action, "
            "something has gone terribly wrong."
        )
    action = lines[-2][len(action_prefix) :]
    action_input = last_line[len(input_prefix) :]
    return action, action_input.strip(" ").strip('"')
|
||||
|
||||
|
||||
class MRKLChain(Chain, BaseModel):
    """Chain that lets a router LLM pick tools, in the style of a MRKL system.

    Example:
        .. code-block:: python

            from langchain import OpenAI, Prompt, MRKLChain

            llm = OpenAI(temperature=0)
            prompt = Prompt(...)
            action_to_chain_map = {...}
            mrkl = MRKLChain(
                llm=llm,
                prompt=prompt,
                action_to_chain_map=action_to_chain_map
            )
    """

    llm: LLM
    """LLM wrapper that acts as the router."""
    prompt: BasePrompt
    """Prompt given to the router LLM."""
    action_to_chain_map: Dict[str, Callable]
    """Mapping from action name to the callable that executes it."""
    verbose: bool = False
    """Verbosity flag forwarded to the ChainedInput that tracks the run."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_chains(
        cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
    ) -> "MRKLChain":
        """Build a MRKL chain from a list of tool configurations.

        The tool names and descriptions are rendered into ``BASE_TEMPLATE`` to
        produce the router prompt, and the name-to-callable mapping is derived
        from the same list.

        Args:
            llm: The LLM to use as the router LLM.
            chains: The tools the MRKL system has access to.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            An initialized MRKL chain.

        Example:
            .. code-block:: python

                from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
                from langchain.chains.mrkl.base import ChainConfig
                llm = OpenAI(temperature=0)
                search = SerpAPIChain()
                llm_math_chain = LLMMathChain(llm=llm)
                chains = [
                    ChainConfig(
                        action_name="Search",
                        action=search.search,
                        action_description="useful for searching"
                    ),
                    ChainConfig(
                        action_name="Calculator",
                        action=llm_math_chain.run,
                        action_description="useful for doing math"
                    )
                ]
                mrkl = MRKLChain.from_chains(llm, chains)
        """
        tool_lines = []
        names = []
        routes = {}
        for config in chains:
            tool_lines.append(f"{config.action_name}: {config.action_description}")
            names.append(config.action_name)
            routes[config.action_name] = config.action
        template = BASE_TEMPLATE.format(
            tools="\n".join(tool_lines), tool_names=", ".join(names)
        )
        router_prompt = Prompt(template=template, input_variables=["input"])
        return cls(llm=llm, prompt=router_prompt, action_to_chain_map=routes, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the single expected input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single produced output key.

        :meta private:
        """
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Alternate LLM thoughts and tool calls until a final answer appears.

        Each LLM generation is appended to the running input and parsed with
        ``get_action_and_input``. Unless it is a final answer, the selected
        callable is invoked and its result is appended as an observation
        before the LLM is asked again.
        """
        router = LLMChain(llm=self.llm, prompt=self.prompt)
        transcript = ChainedInput(
            f"{inputs[self.input_key]}\nThought:", verbose=self.verbose
        )
        # Green is used for the LLM's own text, so tools get the other colors.
        tool_colors = get_color_mapping(
            list(self.action_to_chain_map), excluded_colors=["green"]
        )
        while True:
            generation = router.predict(
                input=transcript.input, stop=["\nObservation"]
            )
            transcript.add(generation, color="green")
            action, action_input = get_action_and_input(generation)
            if action == FINAL_ANSWER_ACTION:
                break
            observation = self.action_to_chain_map[action](action_input)
            transcript.add("\nObservation: ")
            transcript.add(observation, color=tool_colors[action])
            transcript.add("\nThought:")
        return {self.output_key: action_input}

    def run(self, _input: str) -> str:
        """Run a question through the MRKL system and return the answer."""
        outputs = self({self.input_key: _input})
        return outputs[self.output_key]
|
19
langchain/chains/mrkl/prompt.py
Normal file
19
langchain/chains/mrkl/prompt.py
Normal file
@ -0,0 +1,19 @@
|
||||
# flake8: noqa
|
||||
BASE_TEMPLATE = """Answer the following questions as best you can. You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
|
||||
Question: {{input}}"""
|
@ -1,9 +1,20 @@
|
||||
"""Handle chained inputs."""
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
_COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102}


def get_color_mapping(
    items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
    """Assign each item one of the supported highlight colors.

    Colors are taken in the order they appear in ``_COLOR_MAPPING`` and wrap
    around when there are more items than available colors.

    Args:
        items: Names that each need a color.
        excluded_colors: Color names that must not be handed out, or ``None``
            to allow every supported color.

    Returns:
        A mapping from each item to a color name.

    Raises:
        ValueError: If there are items to color but every supported color has
            been excluded.
    """
    colors = list(_COLOR_MAPPING)
    if excluded_colors is not None:
        colors = [c for c in colors if c not in excluded_colors]
    if items and not colors:
        # Without this guard the modulo below fails with ZeroDivisionError.
        raise ValueError(
            "No colors are available: every supported color was excluded."
        )
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None) -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
if color is None:
|
||||
|
70
tests/unit_tests/chains/test_mrkl.py
Normal file
70
tests/unit_tests/chains/test_mrkl.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Test MRKL functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.prompt import Prompt
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def test_get_action_and_input() -> None:
    """Check that an action and its input are parsed from LLM output."""
    output_lines = [
        "Thought: I need to search for NBA",
        "Action: Search",
        "Action Input: NBA",
    ]
    action, action_input = get_action_and_input("\n".join(output_lines))
    assert (action, action_input) == ("Search", "NBA")
|
||||
|
||||
|
||||
def test_get_final_answer() -> None:
    """Check that a trailing final-answer line is recognized."""
    output_lines = [
        "Thought: I need to search for NBA",
        "Action: Search",
        "Action Input: NBA",
        "Observation: founded in 1994",
        "Thought: I can now answer the question",
        "Final Answer: 1994",
    ]
    action, action_input = get_action_and_input("\n".join(output_lines))
    assert (action, action_input) == ("Final Answer: ", "1994")
|
||||
|
||||
|
||||
def test_bad_action_input_line() -> None:
    """Check that a missing action input line raises ValueError."""
    output_lines = [
        "Thought: I need to search for NBA",
        "Action: Search",
        "Thought: NBA",
    ]
    llm_output = "\n".join(output_lines)
    with pytest.raises(ValueError):
        get_action_and_input(llm_output)
|
||||
|
||||
|
||||
def test_bad_action_line() -> None:
    """Check that a missing action line raises ValueError."""
    output_lines = [
        "Thought: I need to search for NBA",
        "Thought: Search",
        "Action Input: NBA",
    ]
    llm_output = "\n".join(output_lines)
    with pytest.raises(ValueError):
        get_action_and_input(llm_output)
|
||||
|
||||
|
||||
def test_from_chains() -> None:
    """Check that from_chains renders the tool list into the prompt template."""
    foo_config = ChainConfig(
        action_name="foo", action=lambda x: "foo", action_description="foobar1"
    )
    bar_config = ChainConfig(
        action_name="bar", action=lambda x: "bar", action_description="foobar2"
    )
    mrkl_chain = MRKLChain.from_chains(FakeLLM(), [foo_config, bar_config])

    expected_template = BASE_TEMPLATE.format(
        tools="foo: foobar1\nbar: foobar2", tool_names="foo, bar"
    )
    prompt = mrkl_chain.prompt
    assert isinstance(prompt, Prompt)
    assert prompt.template == expected_template
|
@ -3,7 +3,7 @@
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
|
||||
|
||||
def test_chained_input_not_verbose() -> None:
|
||||
@ -50,3 +50,25 @@ def test_chained_input_verbose() -> None:
|
||||
output = mystdout.getvalue()
|
||||
assert output == "\x1b[104mbaz\x1b[0m"
|
||||
assert chained_input.input == "foobarbaz"
|
||||
|
||||
|
||||
def test_get_color_mapping() -> None:
    """Check color assignment for a short and a long list of items."""
    # A couple of items receive the first colors in order.
    few_items = ["foo", "bar"]
    assert get_color_mapping(few_items) == {"foo": "blue", "bar": "yellow"}

    # With more items than colors, every item still receives a color.
    many_items = [f"foo-{i}" for i in range(20)]
    assert len(get_color_mapping(many_items)) == 20
|
||||
|
||||
|
||||
def test_get_color_mapping_excluded_colors() -> None:
    """Check that excluded colors are skipped when assigning colors."""
    mapping = get_color_mapping(["foo", "bar"], excluded_colors=["blue"])
    assert mapping == {"foo": "yellow", "bar": "red"}
|
||||
|
Loading…
Reference in New Issue
Block a user