establish router

harrison/router
Harrison Chase 2 years ago
parent 8869b0ab0e
commit 6f55fa8ba7

@ -32,7 +32,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 2,
"id": "07e96d99", "id": "07e96d99",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -64,7 +64,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 3,
"id": "a069c4b6", "id": "a069c4b6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -74,7 +74,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 4,
"id": "e603cd7d", "id": "e603cd7d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -86,32 +86,32 @@
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n", "What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n",
"Thought:\u001b[102m I need to find the age of Olivia Wilde's boyfriend\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Olivia Wilde's boyfriend\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n", "Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n",
"Observation: \u001b[104mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
"Thought:\u001b[102m I need to find the age of Harry Styles\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Harry Styles\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n", "Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[104m28 years\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n",
"Thought:\u001b[102m I need to calculate 28 to the 0.23 power\n", "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 to the 0.23 power\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 28^0.23\u001b[0m\n", "Action Input: 28^0.23\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"28^0.23\u001b[102m\n", "28^0.23\u001b[32;1m\u001b[1;3m\n",
"\n", "\n",
"```python\n", "```python\n",
"print(28**0.23)\n", "print(28**0.23)\n",
"```\n", "```\n",
"\u001b[0m\n", "\u001b[0m\n",
"Answer: \u001b[103m2.1520202182226886\n", "Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n",
"\u001b[0m\n", "\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\n", "\n",
"Observation: \u001b[103mAnswer: 2.1520202182226886\n", "Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n",
"\u001b[0m\n", "\u001b[0m\n",
"Thought:\u001b[102m I now know the final answer\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 2.1520202182226886\u001b[0m\n", "Final Answer: 2.1520202182226886\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -122,7 +122,7 @@
"'2.1520202182226886'" "'2.1520202182226886'"
] ]
}, },
"execution_count": 6, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -133,7 +133,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 5,
"id": "a5c07010", "id": "a5c07010",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -145,35 +145,35 @@
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n", "Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n",
"Thought:\u001b[102m I need to find an album called 'The Storm Before the Calm'\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find an album called 'The Storm Before the Calm'\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"The Storm Before the Calm album\"\u001b[0m\n", "Action Input: \"The Storm Before the Calm album\"\u001b[0m\n",
"Observation: \u001b[104mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n",
"Thought:\u001b[102m I need to check if Alanis is in the FooBar database\n", "Thought:\u001b[32;1m\u001b[1;3m I need to check if Alanis is in the FooBar database\n",
"Action: FooBar DB\n", "Action: FooBar DB\n",
"Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n", "Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"Does Alanis Morissette exist in the FooBar database?\n", "Does Alanis Morissette exist in the FooBar database?\n",
"SQLQuery:\u001b[102m SELECT * FROM Artist WHERE Name = 'Alanis Morissette'\u001b[0m\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT * FROM Artist WHERE Name = 'Alanis Morissette'\u001b[0m\n",
"SQLResult: \u001b[103m[(4, 'Alanis Morissette')]\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(4, 'Alanis Morissette')]\u001b[0m\n",
"Answer:\u001b[102m Yes\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Yes\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\n", "\n",
"Observation: \u001b[101m Yes\u001b[0m\n", "Observation: \u001b[38;5;200m\u001b[1;3m Yes\u001b[0m\n",
"Thought:\u001b[102m I need to find out what albums of Alanis's are in the FooBar database\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find out what albums of Alanis's are in the FooBar database\n",
"Action: FooBar DB\n", "Action: FooBar DB\n",
"Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n", "Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"What albums by Alanis Morissette are in the FooBar database?\n", "What albums by Alanis Morissette are in the FooBar database?\n",
"SQLQuery:\u001b[102m SELECT Title FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alanis Morissette')\u001b[0m\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette'\u001b[0m\n",
"SQLResult: \u001b[103m[('Jagged Little Pill',)]\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n",
"Answer:\u001b[102m Jagged Little Pill\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Jagged Little Pill\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\n", "\n",
"Observation: \u001b[101m Jagged Little Pill\u001b[0m\n", "Observation: \u001b[38;5;200m\u001b[1;3m Jagged Little Pill\u001b[0m\n",
"Thought:\u001b[102m I now know the final answer\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The album is by Alanis Morissette and the albums in the FooBar database by her are Jagged Little Pill\u001b[0m\n", "Final Answer: The album is by Alanis Morissette and the albums in the FooBar database by her are Jagged Little Pill\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -184,7 +184,7 @@
"'The album is by Alanis Morissette and the albums in the FooBar database by her are Jagged Little Pill'" "'The album is by Alanis Morissette and the albums in the FooBar database by her are Jagged Little Pill'"
] ]
}, },
"execution_count": 10, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }

@ -37,14 +37,14 @@
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"Author David Chanoff has collaborated with a U.S. Navy admiral who served as the ambassador to the United Kingdom under which President?\n", "Author David Chanoff has collaborated with a U.S. Navy admiral who served as the ambassador to the United Kingdom under which President?\n",
"Thought 1:\u001b[102m I need to search David Chanoff and find the U.S. Navy admiral he\n", "Thought 1:\u001b[32;1m\u001b[1;3m I need to search David Chanoff and find the U.S. Navy admiral he collaborated\n",
"collaborated with.\n", "with.\n",
"Action 1: Search[David Chanoff]\u001b[0m\n", "Action 1: Search[David Chanoff]\u001b[0m\n",
"Observation 1: \u001b[103mDavid Chanoff is a noted author of non-fiction work. His work has typically involved collaborations with the principal protagonist of the work concerned. His collaborators have included; Augustus A. White, Joycelyn Elders, Đoàn Văn Toại, William J. Crowe, Ariel Sharon, Kenneth Good and Felix Zandman. He has also written about a wide range of subjects including literary history, education and foreign for The Washington Post, The New Republic and The New York Times Magazine. He has published more than twelve books.\u001b[0m\n", "Observation 1: \u001b[33;1m\u001b[1;3mDavid Chanoff is a noted author of non-fiction work. His work has typically involved collaborations with the principal protagonist of the work concerned. His collaborators have included; Augustus A. White, Joycelyn Elders, Đoàn Văn Toại, William J. Crowe, Ariel Sharon, Kenneth Good and Felix Zandman. He has also written about a wide range of subjects including literary history, education and foreign for The Washington Post, The New Republic and The New York Times Magazine. He has published more than twelve books.\u001b[0m\n",
"Thought 2:\u001b[102m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe.\n", "Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe.\n",
"Action 2: Search[William J. Crowe]\u001b[0m\n", "Action 2: Search[William J. Crowe]\u001b[0m\n",
"Observation 2: \u001b[103mWilliam James Crowe Jr. (January 2, 1925 October 18, 2007) was a United States Navy admiral and diplomat who served as the 11th chairman of the Joint Chiefs of Staff under Presidents Ronald Reagan and George H. W. Bush, and as the ambassador to the United Kingdom and Chair of the Intelligence Oversight Board under President Bill Clinton.\u001b[0m\n", "Observation 2: \u001b[33;1m\u001b[1;3mWilliam James Crowe Jr. (January 2, 1925 October 18, 2007) was a United States Navy admiral and diplomat who served as the 11th chairman of the Joint Chiefs of Staff under Presidents Ronald Reagan and George H. W. Bush, and as the ambassador to the United Kingdom and Chair of the Intelligence Oversight Board under President Bill Clinton.\u001b[0m\n",
"Thought 3:\u001b[102m William J. Crowe served as the ambassador to the United Kingdom under President Bill Clinton. So the answer is Bill Clinton.\n", "Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under President Bill Clinton.\n",
"Action 3: Finish[Bill Clinton]\u001b[0m\n", "Action 3: Finish[Bill Clinton]\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -68,7 +68,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "0a6bd3b4", "id": "b5a265c6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

@ -24,19 +24,19 @@
"\n", "\n",
"\u001b[1m> Entering new chain...\u001b[0m\n", "\u001b[1m> Entering new chain...\u001b[0m\n",
"What is the hometown of the reigning men's U.S. Open champion?\n", "What is the hometown of the reigning men's U.S. Open champion?\n",
"Are follow up questions needed here:\u001b[102m Yes.\n", "Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
"Intermediate answer: \u001b[103mCarlos Alcaraz won the 2022 Men's single title while Poland's Iga Swiatek won the Women's single title defeating Tunisian's Ons Jabeur..\u001b[0m\u001b[102m\n", "Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n", "Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
"Intermediate answer: \u001b[103mEl Palmar, Murcia, Spain.\u001b[0m\u001b[102m\n", "Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"So the final answer is: El Palmar, Murcia, Spain\u001b[0m\n", "So the final answer is: El Palmar, Spain\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'\\nSo the final answer is: El Palmar, Murcia, Spain'" "'\\nSo the final answer is: El Palmar, Spain'"
] ]
}, },
"execution_count": 1, "execution_count": 1,

@ -1,14 +1,15 @@
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf.""" """Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from typing import Any, Callable, Dict, List, NamedTuple, Tuple from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.mrkl.prompt import BASE_TEMPLATE from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.chains.router import LLMRouterChain
from langchain.input import ChainedInput, get_color_mapping from langchain.input import ChainedInput, get_color_mapping
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.prompts import BasePrompt, Prompt from langchain.prompts import Prompt
FINAL_ANSWER_ACTION = "Final Answer: " FINAL_ANSWER_ACTION = "Final Answer: "
@ -48,6 +49,25 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
return action, action_input.strip(" ").strip('"') return action, action_input.strip(" ").strip('"')
class MRKLRouterChain(LLMRouterChain):
"""Router for the MRKL chain."""
def __init__(self, llm: LLM, chain_configs: List[ChainConfig], **kwargs: Any):
"""Initialize with an LLM and the chain configs it has access to."""
tools = "\n".join(
[f"{c.action_name}: {c.action_description}" for c in chain_configs]
)
tool_names = ", ".join([chain.action_name for chain in chain_configs])
template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names)
prompt = Prompt(template=template, input_variables=["input"])
llm_chain = LLMChain(llm=llm, prompt=prompt)
stops = ["\nObservation"]
super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
return get_action_and_input(text)
class MRKLChain(Chain, BaseModel): class MRKLChain(Chain, BaseModel):
"""Chain that implements the MRKL system. """Chain that implements the MRKL system.
@ -68,8 +88,8 @@ class MRKLChain(Chain, BaseModel):
llm: LLM llm: LLM
"""LLM wrapper to use as router.""" """LLM wrapper to use as router."""
prompt: BasePrompt chain_configs: List[ChainConfig]
"""Prompt to use as router.""" """Chain configs this chain has access to."""
action_to_chain_map: Dict[str, Callable] action_to_chain_map: Dict[str, Callable]
"""Mapping from action name to chain to execute.""" """Mapping from action name to chain to execute."""
input_key: str = "question" #: :meta private: input_key: str = "question" #: :meta private:
@ -114,15 +134,12 @@ class MRKLChain(Chain, BaseModel):
] ]
mrkl = MRKLChain.from_chains(llm, chains) mrkl = MRKLChain.from_chains(llm, chains)
""" """
tools = "\n".join(
[f"{chain.action_name}: {chain.action_description}" for chain in chains]
)
tool_names = ", ".join([chain.action_name for chain in chains])
template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names)
prompt = Prompt(template=template, input_variables=["input"])
action_to_chain_map = {chain.action_name: chain.action for chain in chains} action_to_chain_map = {chain.action_name: chain.action for chain in chains}
return cls( return cls(
llm=llm, prompt=prompt, action_to_chain_map=action_to_chain_map, **kwargs llm=llm,
chain_configs=chains,
action_to_chain_map=action_to_chain_map,
**kwargs,
) )
class Config: class Config:
@ -148,7 +165,7 @@ class MRKLChain(Chain, BaseModel):
return [self.output_key] return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) router_chain = MRKLRouterChain(self.llm, self.chain_configs)
chained_input = ChainedInput( chained_input = ChainedInput(
f"{inputs[self.input_key]}\nThought:", verbose=self.verbose f"{inputs[self.input_key]}\nThought:", verbose=self.verbose
) )
@ -156,11 +173,10 @@ class MRKLChain(Chain, BaseModel):
list(self.action_to_chain_map.keys()), excluded_colors=["green"] list(self.action_to_chain_map.keys()), excluded_colors=["green"]
) )
while True: while True:
thought = llm_chain.predict( action, action_input, thought = router_chain.get_action_and_input(
input=chained_input.input, stop=["\nObservation"] chained_input.input
) )
chained_input.add(thought, color="green") chained_input.add(thought, color="green")
action, action_input = get_action_and_input(thought)
if action == FINAL_ANSWER_ACTION: if action == FINAL_ANSWER_ACTION:
return {self.output_key: action_input} return {self.output_key: action_input}
chain = self.action_to_chain_map[action] chain = self.action_to_chain_map[action]

@ -1,38 +1,47 @@
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
import re import re
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.react.prompt import PROMPT from langchain.chains.react.prompt import PROMPT
from langchain.chains.router import LLMRouterChain
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.input import ChainedInput from langchain.input import ChainedInput
from langchain.llms.base import LLM from langchain.llms.base import LLM
def predict_until_observation( class ReActRouterChain(LLMRouterChain, BaseModel):
llm_chain: LLMChain, prompt: str, i: int """Router for the ReAct chain."""
) -> Tuple[str, str, str]:
"""Generate text until an observation is needed.""" i: int = 1
action_prefix = f"Action {i}: "
stop_seq = f"\nObservation {i}:" def __init__(self, llm: LLM, **kwargs: Any):
ret_text = llm_chain.predict(input=prompt, stop=[stop_seq]) """Initialize with the language model."""
# Sometimes the LLM forgets to take an action, so we prompt it to. llm_chain = LLMChain(llm=llm, prompt=PROMPT)
while not ret_text.split("\n")[-1].startswith(action_prefix): stops = ["\nObservation 1:"]
ret_text += f"\nAction {i}:" super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
new_text = llm_chain.predict(input=prompt + ret_text, stop=[stop_seq])
ret_text += new_text def _fix_text(self, text: str) -> str:
# The action block should be the last line. return text + f"\nAction {self.i}:"
action_block = ret_text.split("\n")[-1]
action_str = action_block[len(action_prefix) :] def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
# Parse out the action and the directive. action_prefix = f"Action {self.i}: "
re_matches = re.search(r"(.*?)\[(.*?)\]", action_str) if not text.split("\n")[-1].startswith(action_prefix):
if re_matches is None: return None
raise ValueError(f"Could not parse action directive: {action_str}") self.i += 1
return ret_text, re_matches.group(1), re_matches.group(2) self.stops = [f"\nObservation {self.i}:"]
action_block = text.split("\n")[-1]
action_str = action_block[len(action_prefix) :]
# Parse out the action and the directive.
re_matches = re.search(r"(.*?)\[(.*?)\]", action_str)
if re_matches is None:
raise ValueError(f"Could not parse action directive: {action_str}")
return re_matches.group(1), re_matches.group(2)
class ReActChain(Chain, BaseModel): class ReActChain(Chain, BaseModel):
@ -76,13 +85,12 @@ class ReActChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
question = inputs[self.input_key] question = inputs[self.input_key]
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) router_chain = ReActRouterChain(self.llm)
chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose) chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose)
i = 1
document = None document = None
while True: while True:
ret_text, action, directive = predict_until_observation( action, directive, ret_text = router_chain.get_action_and_input(
llm_chain, chained_input.input, i chained_input.input
) )
chained_input.add(ret_text, color="green") chained_input.add(ret_text, color="green")
if action == "Search": if action == "Search":
@ -101,7 +109,6 @@ class ReActChain(Chain, BaseModel):
return {self.output_key: directive} return {self.output_key: directive}
else: else:
raise ValueError(f"Got unknown action directive: {action}") raise ValueError(f"Got unknown action directive: {action}")
chained_input.add(f"\nObservation {i}: ") chained_input.add(f"\nObservation {router_chain.i - 1}: ")
chained_input.add(observation, color="yellow") chained_input.add(observation, color="yellow")
chained_input.add(f"\nThought {i + 1}:") chained_input.add(f"\nThought {router_chain.i}:")
i += 1

@ -0,0 +1,75 @@
"""Chain that takes in an input and produces an action and action input."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class RouterChain(Chain, BaseModel, ABC):
    """Chain responsible for deciding the action to take.

    Subclasses implement :meth:`get_action_and_input`, which maps free-form
    text into an ``(action, action_input, log)`` triple; this base class
    handles exposing that triple through the standard Chain input/output
    key machinery.
    """

    # Key under which the incoming text is supplied.
    input_key: str = "input_text"  #: :meta private:
    # Keys under which the parsed action, its input, and the raw log
    # are returned from ``_call``.
    action_key: str = "action"  #: :meta private:
    action_input_key: str = "action_input"  #: :meta private:
    log_key: str = "log"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Will be the input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return three keys: the action, the action input, and the log.

        :meta private:
        """
        return [self.action_key, self.action_input_key, self.log_key]

    @abstractmethod
    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return action, action input, and log (in that order)."""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Delegate to the subclass hook, then repackage the resulting
        # triple under this chain's declared output keys.
        action, action_input, log = self.get_action_and_input(inputs[self.input_key])
        return {
            self.action_key: action,
            self.action_input_key: action_input,
            self.log_key: log,
        }
class LLMRouterChain(RouterChain, BaseModel, ABC):
    """RouterChain that uses an LLM.

    Drives an :class:`LLMChain` and parses its text output into an
    action/action-input pair via the subclass hook
    :meth:`_extract_action_and_input`.
    """

    # LLM chain that produces the raw text to be parsed for an action.
    llm_chain: LLMChain
    # Stop sequences passed on every predict call. Subclasses may update
    # this between calls (e.g. to advance a per-step observation marker).
    stops: Optional[List[str]]

    @abstractmethod
    def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Extract action and action input from llm output.

        Returning ``None`` signals that the output could not be parsed;
        the caller will then repair it via ``_fix_text`` and re-prompt.
        """

    def _fix_text(self, text: str) -> str:
        """Fix the text."""
        # Default: unparseable LLM output is a hard error. Subclasses that
        # can nudge the model back on track override this to amend `text`.
        raise ValueError("fix_text not implemented for this router.")

    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return action, action input, and log (in that order)."""
        input_key = self.llm_chain.input_keys[0]
        inputs = {input_key: text, "stop": self.stops}
        full_output = self.llm_chain.predict(**inputs)
        parsed_output = self._extract_action_and_input(full_output)
        # Retry loop: amend the accumulated output with _fix_text, feed the
        # amended transcript back to the LLM, and try parsing again until a
        # valid (action, action_input) pair is produced.
        while parsed_output is None:
            full_output = self._fix_text(full_output)
            inputs = {input_key: text + full_output, "stop": self.stops}
            output = self.llm_chain.predict(**inputs)
            full_output += output
            parsed_output = self._extract_action_and_input(full_output)
        action, action_input = parsed_output
        # The full transcript doubles as the log entry for verbose display.
        return action, action_input, full_output

@ -1,77 +1,46 @@
"""Chain that does self ask with search.""" """Chain that does self ask with search."""
from typing import Any, Dict, List from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.router import LLMRouterChain
from langchain.chains.self_ask_with_search.prompt import PROMPT from langchain.chains.self_ask_with_search.prompt import PROMPT
from langchain.chains.serpapi import SerpAPIChain from langchain.chains.serpapi import SerpAPIChain
from langchain.input import ChainedInput from langchain.input import ChainedInput
from langchain.llms.base import LLM from langchain.llms.base import LLM
def extract_answer(generated: str) -> str: class SelfAskWithSearchRouter(LLMRouterChain):
"""Extract answer from text.""" """Router for the self-ask-with-search paper."""
if "\n" not in generated:
last_line = generated
else:
last_line = generated.split("\n")[-1]
if ":" not in last_line: def __init__(self, llm: LLM, **kwargs: Any):
after_colon = last_line """Initialize with an LLM."""
else: llm_chain = LLMChain(llm=llm, prompt=PROMPT)
after_colon = generated.split(":")[-1] super().__init__(llm_chain=llm_chain, **kwargs)
if " " == after_colon[0]: def _extract_action_and_input(self, text: str) -> Tuple[str, str]:
after_colon = after_colon[1:] followup = "Follow up:"
if "." == after_colon[-1]: if "\n" not in text:
after_colon = after_colon[:-1] last_line = text
else:
return after_colon last_line = text.split("\n")[-1]
def extract_question(generated: str, followup: str) -> str:
"""Extract question from text."""
if "\n" not in generated:
last_line = generated
else:
last_line = generated.split("\n")[-1]
if followup not in last_line:
print("we probably should never get here..." + generated)
if ":" not in last_line:
after_colon = last_line
else:
after_colon = generated.split(":")[-1]
if " " == after_colon[0]:
after_colon = after_colon[1:]
if "?" != after_colon[-1]:
print("we probably should never get here..." + generated)
return after_colon
def get_last_line(generated: str) -> str:
"""Get the last line in text."""
if "\n" not in generated:
last_line = generated
else:
last_line = generated.split("\n")[-1]
return last_line
if followup not in last_line:
return "Final Answer", text
def greenify(_input: str) -> str: if ":" not in last_line:
"""Add green highlighting to text.""" after_colon = last_line
return "\x1b[102m" + _input + "\x1b[0m" else:
after_colon = text.split(":")[-1]
if " " == after_colon[0]:
after_colon = after_colon[1:]
if "?" != after_colon[-1]:
print("we probably should never get here..." + text)
def yellowfy(_input: str) -> str: return "Intermediate Answer", after_colon
"""Add yellow highlighting to text."""
return "\x1b[106m" + _input + "\x1b[0m"
class SelfAskWithSearchChain(Chain, BaseModel): class SelfAskWithSearchChain(Chain, BaseModel):
@ -117,33 +86,14 @@ class SelfAskWithSearchChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose) chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
chained_input.add("\nAre follow up questions needed here:") chained_input.add("\nAre follow up questions needed here:")
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
intermediate = "\nIntermediate answer:" intermediate = "\nIntermediate answer:"
followup = "Follow up:" router = SelfAskWithSearchRouter(self.llm, stops=[intermediate])
finalans = "\nSo the final answer is:" action, action_input, log = router.get_action_and_input(chained_input.input)
ret_text = llm_chain.predict(input=chained_input.input, stop=[intermediate]) chained_input.add(log, color="green")
chained_input.add(ret_text, color="green") while action != "Final Answer":
while followup in get_last_line(ret_text): external_answer = self.search_chain.run(action_input)
question = extract_question(ret_text, followup) chained_input.add(intermediate + " ")
external_answer = self.search_chain.run(question) chained_input.add(external_answer + ".", color="yellow")
if external_answer is not None: action, action_input, log = router.get_action_and_input(chained_input.input)
chained_input.add(intermediate + " ") chained_input.add(log, color="green")
chained_input.add(external_answer + ".", color="yellow") return {self.output_key: action_input}
ret_text = llm_chain.predict(
input=chained_input.input, stop=["\nIntermediate answer:"]
)
chained_input.add(ret_text, color="green")
else:
# We only get here in the very rare case that Google returns no answer.
chained_input.add(intermediate + " ")
preds = llm_chain.predict(
input=chained_input.input, stop=["\n" + followup, finalans]
)
chained_input.add(preds, color="green")
if finalans not in ret_text:
chained_input.add(finalans)
ret_text = llm_chain.predict(input=chained_input.input, stop=["\n"])
chained_input.add(ret_text, color="green")
return {self.output_key: ret_text}

@ -2,7 +2,11 @@
import pytest import pytest
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input from langchain.chains.mrkl.base import (
ChainConfig,
MRKLRouterChain,
get_action_and_input,
)
from langchain.chains.mrkl.prompt import BASE_TEMPLATE from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.prompts import Prompt from langchain.prompts import Prompt
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM
@ -59,12 +63,12 @@ def test_from_chains() -> None:
action_name="bar", action=lambda x: "bar", action_description="foobar2" action_name="bar", action=lambda x: "bar", action_description="foobar2"
), ),
] ]
mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs) router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2" expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar" expected_tool_names = "foo, bar"
expected_template = BASE_TEMPLATE.format( expected_template = BASE_TEMPLATE.format(
tools=expected_tools_prompt, tool_names=expected_tool_names tools=expected_tools_prompt, tool_names=expected_tool_names
) )
prompt = mrkl_chain.prompt prompt = router_chain.llm_chain.prompt
assert isinstance(prompt, Prompt) assert isinstance(prompt, Prompt)
assert prompt.template == expected_template assert prompt.template == expected_template

@ -4,8 +4,7 @@ from typing import Any, List, Mapping, Optional, Union
import pytest import pytest
from langchain.chains.llm import LLMChain from langchain.chains.react.base import ReActChain, ReActRouterChain
from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -53,8 +52,8 @@ def test_predict_until_observation_normal() -> None:
"""Test predict_until_observation when observation is made normally.""" """Test predict_until_observation when observation is made normally."""
outputs = ["foo\nAction 1: search[foo]"] outputs = ["foo\nAction 1: search[foo]"]
fake_llm = FakeListLLM(outputs) fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT) router_chain = ReActRouterChain(llm=fake_llm)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1) action, directive, ret_text = router_chain.get_action_and_input("")
assert ret_text == outputs[0] assert ret_text == outputs[0]
assert action == "search" assert action == "search"
assert directive == "foo" assert directive == "foo"
@ -64,22 +63,13 @@ def test_predict_until_observation_repeat() -> None:
"""Test when no action is generated initially.""" """Test when no action is generated initially."""
outputs = ["foo", " search[foo]"] outputs = ["foo", " search[foo]"]
fake_llm = FakeListLLM(outputs) fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT) router_chain = ReActRouterChain(llm=fake_llm)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1) action, directive, ret_text = router_chain.get_action_and_input("")
assert ret_text == "foo\nAction 1: search[foo]" assert ret_text == "foo\nAction 1: search[foo]"
assert action == "search" assert action == "search"
assert directive == "foo" assert directive == "foo"
def test_predict_until_observation_error() -> None:
"""Test handling of generation of text that cannot be parsed."""
outputs = ["foo\nAction 1: foo"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
with pytest.raises(ValueError):
predict_until_observation(fake_llm_chain, "", 1)
def test_react_chain() -> None: def test_react_chain() -> None:
"""Test react chain.""" """Test react chain."""
responses = [ responses = [

Loading…
Cancel
Save