harrison/router_docs
Harrison Chase 2 years ago
parent 4ccb9b684a
commit 505cb2eb62

@@ -27,7 +27,7 @@
"outputs": [],
"source": [
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n",
"from langchain.chains.mrkl.base import ChainConfig"
"from langchain.routing_chains.mrkl.base import ChainConfig"
]
},
{
@@ -167,7 +167,7 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"What albums by Alanis Morissette are in the FooBar database?\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette'\u001b[0m\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = \"Alanis Morissette\"\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Jagged Little Pill\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",

@@ -41,10 +41,12 @@
"with.\n",
"Action 1: Search[David Chanoff]\u001b[0m\n",
"Observation 1: \u001b[36;1m\u001b[1;3mDavid Chanoff is a noted author of non-fiction work. His work has typically involved collaborations with the principal protagonist of the work concerned. His collaborators have included; Augustus A. White, Joycelyn Elders, Đoàn Văn Toại, William J. Crowe, Ariel Sharon, Kenneth Good and Felix Zandman. He has also written about a wide range of subjects including literary history, education and foreign for The Washington Post, The New Republic and The New York Times Magazine. He has published more than twelve books.\u001b[0m\n",
"Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe.\n",
"Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe. I\n",
"need to search him next.\n",
"Action 2: Search[William J. Crowe]\u001b[0m\n",
"Observation 2: \u001b[36;1m\u001b[1;3mWilliam James Crowe Jr. (January 2, 1925 October 18, 2007) was a United States Navy admiral and diplomat who served as the 11th chairman of the Joint Chiefs of Staff under Presidents Ronald Reagan and George H. W. Bush, and as the ambassador to the United Kingdom and Chair of the Intelligence Oversight Board under President Bill Clinton.\u001b[0m\n",
"Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under President Bill Clinton.\n",
"Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under\n",
"President Bill Clinton. So the answer is Bill Clinton.\n",
"Action 3: Finish[Bill Clinton]\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -68,7 +70,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3cb9d77c",
"id": "4ff64e81",
"metadata": {},
"outputs": [],
"source": []

@@ -13,7 +13,7 @@
},
{
"cell_type": "markdown",
"id": "4b21ae68",
"id": "3c6226b9",
"metadata": {},
"source": [
"## Concepts\n",
@@ -26,20 +26,19 @@
},
{
"cell_type": "markdown",
"id": "c9677868",
"id": "05d4b21e",
"metadata": {},
"source": [
"## Tools\n",
"When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of ToolConfigs. The ToolConfig is used not only to create the Routing Chain, but is also sometimes used to create the router itself (often, the router logic depends on the tools available). \n",
"When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of Tools. The Tools are used not only to create the Routing Chain, but is also sometimes used to create the router itself (often, the router logic depends on the tools available). \n",
"\n",
"```python\n",
"class ToolConfig(NamedTuple):\n",
" \"\"\"Configuration for tools.\"\"\"\n",
"class Tool(NamedTuple):\n",
" \"\"\"Interface for tools.\"\"\"\n",
"\n",
" tool_name: str\n",
" tool: Callable[[str], str]\n",
" # Needed to construct some routers.\n",
" tool_description: Optional[str] = None\n",
" name: str\n",
" func: Callable[[str], str]\n",
" description: Optional[str] = None\n",
"```\n",
"\n",
"The two required components of a ToolConfig are the name and then the tool itself. A tool description is optional, as it is needed for some routers but not all."
@@ -50,22 +49,123 @@
"id": "2558a02d",
"metadata": {},
"source": [
"## Loading the chains\n",
"\n",
"```python\n",
"from langchain.routing_chains import load_chain\n",
"\n",
"tools: List[ToolConfig] = [...]\n",
"llm: LLM = OpenAI(temperature=0)\n",
"router_type: str = \"zero-shot\"\n",
"chain = load_chain(tools, llm, router_type, verbose=True)\n",
"```"
"## Loading the chains\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "36ed392e",
"metadata": {},
"outputs": [],
"source": [
"# Import things that are needed generically\n",
"from langchain.routing_chains import load_routing_chain, Tool\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "56ff7670",
"metadata": {},
"outputs": [],
"source": [
"# Load the tool configs that are needed.\n",
"from langchain import LLMMathChain, SerpAPIChain\n",
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\"\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5b93047d",
"metadata": {},
"outputs": [],
"source": [
"# Construct the routing chain. We will use the default router type here.\n",
"# See documentation for a full list of options.\n",
"router_llm = OpenAI(temperature=0)\n",
"chain = load_routing_chain(tools, router_llm, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6f96a891",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Olivia Wilde's boyfriend\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Harry Styles\n",
"Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 28^0.23\u001b[0m\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"28^0.23\u001b[32;1m\u001b[1;3m\n",
"\n",
"```python\n",
"print(28**0.23)\n",
"```\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 2.1520202182226886\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'2.1520202182226886'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3254c6c",
"id": "2f0852ff",
"metadata": {},
"outputs": [],
"source": []

@@ -1,9 +1,11 @@
"""Routing chains."""
from langchain.routing_chains.loading import load_routing_chain
from langchain.routing_chains.mrkl.base import MRKLChain
from langchain.routing_chains.react.base import ReActChain
from langchain.routing_chains.router import LLMRouter
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.routing_chains.tools import Tool
__all__ = [
"MRKLChain",
@@ -11,4 +13,6 @@ __all__ = [
"ReActChain",
"LLMRouter",
"RoutingChain",
"Tool",
"load_routing_chain",
]

@@ -0,0 +1,43 @@
"""Load routing chains."""
from typing import Any, List
from langchain.llms.base import LLM
from langchain.routing_chains.mrkl.base import ZeroShotRouter
from langchain.routing_chains.react.base import ReActDocstoreRouter
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchRouter
from langchain.routing_chains.tools import Tool
ROUTER_TYPE_TO_CLASS = {
"zero-shot-react-description": ZeroShotRouter,
"react-docstore": ReActDocstoreRouter,
"self-ask-with-search": SelfAskWithSearchRouter,
}
def load_routing_chain(
tools: List[Tool],
llm: LLM,
router_type: str = "zero-shot-react-description",
**kwargs: Any,
) -> RoutingChain:
"""Load routing chain given tools and LLM.
Args:
tools: List of tools this routing chain has access to.
llm: Language model to use as the router.
router_type: The router to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
**kwargs: Additional key word arguments to pass to the routing chain.
Returns:
A routing chain.
"""
if router_type not in ROUTER_TYPE_TO_CLASS:
raise ValueError(
f"Got unknown router type: {router_type}. "
f"Valid types are: {ROUTER_TYPE_TO_CLASS.keys()}."
)
router_cls = ROUTER_TYPE_TO_CLASS[router_type]
router = router_cls.from_llm_and_tools(llm, tools)
return RoutingChain(router=router, tools=tools, **kwargs)
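A hedged usage sketch of `load_routing_chain` as defined above; the lambda tool is a stub standing in for a real search tool, and any LLM could take the place of `OpenAI` here.

```python
# Sketch only: the lambda below is a stub standing in for a real tool.
from langchain.llms import OpenAI
from langchain.routing_chains import Tool, load_routing_chain

tools = [
    Tool(
        name="Search",
        func=lambda query: "stubbed search result",
        description="useful for when you need to answer questions about current events",
    )
]
llm = OpenAI(temperature=0)

# Any key of ROUTER_TYPE_TO_CLASS is accepted; anything else raises ValueError.
chain = load_routing_chain(tools, llm, router_type="zero-shot-react-description", verbose=True)
```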

@@ -6,7 +6,8 @@ from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE
from langchain.routing_chains.router import LLMRouter
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.tools import Tool
FINAL_ANSWER_ACTION = "Final Answer: "
@@ -46,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
return action, action_input.strip(" ").strip('"')
class MRKLRouterChain(LLMRouter):
class ZeroShotRouter(LLMRouter):
"""Router for the MRKL chain."""
@property
@@ -59,17 +60,15 @@ class MRKLRouterChain(LLMRouter):
"""Prefix to append the router call with."""
return "Thought:"
def __init__(self, llm: LLM, chain_configs: List[ChainConfig], **kwargs: Any):
"""Initialize with an LLM and the chain configs it has access to."""
tools = "\n".join(
[f"{c.action_name}: {c.action_description}" for c in chain_configs]
)
tool_names = ", ".join([chain.action_name for chain in chain_configs])
template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names)
@classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ZeroShotRouter":
"""Construct a router from an LLM and tools."""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
template = BASE_TEMPLATE.format(tools=tool_strings, tool_names=tool_names)
prompt = PromptTemplate(template=template, input_variables=["input"])
llm_chain = LLMChain(llm=llm, prompt=prompt)
stops = ["\nObservation"]
super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
return cls(llm_chain=llm_chain)
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
return get_action_and_input(text)
@@ -128,8 +127,50 @@ class MRKLChain(RoutingChain):
]
mrkl = MRKLChain.from_chains(llm, chains)
"""
router_chain = MRKLRouterChain(llm, chains)
expert_configs = [
ToolConfig(tool_name=c.action_name, tool=c.action) for c in chains
tools = [
Tool(name=c.action_name, func=c.action, description=c.action_description)
for c in chains
]
return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs)
return cls.from_tools_and_llm(tools, llm, **kwargs)
@classmethod
def from_tools_and_llm(
cls, tools: List[Tool], llm: LLM, **kwargs: Any
) -> "MRKLChain":
"""User friendly way to initialize the MRKL chain.
This is intended to be an easy way to get up and running with the
MRKL chain.
Args:
tools: The tools the MRKL system has access to.
llm: The LLM to use as the router LLM.
**kwargs: parameters to be passed to initialization.
Returns:
An initialized MRKL chain.
Example:
.. code-block:: python
from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
from langchain.routing_chains.tools import Tool
llm = OpenAI(temperature=0)
search = SerpAPIChain()
llm_math_chain = LLMMathChain(llm=llm)
tools = [
Tool(
name="Search",
func=search.run,
description="useful for searching"
),
Tool(
name="Calculator",
func=llm_math_chain.run,
description="useful for doing math"
)
]
mrkl = MRKLChain.from_tools_and_llm(tools, llm)
"""
router = ZeroShotRouter.from_llm_and_tools(llm, tools)
return cls(router=router, tools=tools, **kwargs)
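As a quick illustration of the format `ZeroShotRouter` expects back from the LLM, `get_action_and_input` pulls the tool name and tool input out of a completion like the ones in the notebook trace above. This is a sketch that assumes the parser handles `Action:` / `Action Input:` lines as that trace suggests.

```python
# Sketch of the LLM output format the MRKL-style router parses,
# matching the Action / Action Input lines in the notebook trace above.
from langchain.routing_chains.mrkl.base import get_action_and_input

llm_output = (
    " I need to find the age of Harry Styles\n"
    "Action: Search\n"
    'Action Input: "Harry Styles age"'
)
action, action_input = get_action_and_input(llm_output)
# action -> "Search"; action_input -> Harry Styles age (surrounding quotes stripped)
```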

@@ -1,6 +1,6 @@
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
import re
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
from pydantic import BaseModel
@@ -10,19 +10,28 @@ from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.routing_chains.react.prompt import PROMPT
from langchain.routing_chains.router import LLMRouter
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.tools import Tool
class ReActRouterChain(LLMRouter, BaseModel):
class ReActDocstoreRouter(LLMRouter, BaseModel):
"""Router for the ReAct chin."""
i: int = 1
def __init__(self, llm: LLM, **kwargs: Any):
"""Initialize with the language model."""
@classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ReActDocstoreRouter":
"""Construct a router from an LLM and tools."""
if len(tools) != 2:
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
if tool_names != {"Lookup", "Search"}:
raise ValueError(
f"Tool names should be Lookup and Search, got {tool_names}"
)
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
stops = ["\nObservation 1:"]
super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
return cls(llm_chain=llm_chain)
def _fix_text(self, text: str) -> str:
return text + f"\nAction {self.i}:"
@@ -32,7 +41,6 @@ class ReActRouterChain(LLMRouter, BaseModel):
if not text.split("\n")[-1].startswith(action_prefix):
return None
self.i += 1
self.stops = [f"\nObservation {self.i}:"]
action_block = text.split("\n")[-1]
action_str = action_block[len(action_prefix) :]
@@ -43,8 +51,8 @@ class ReActRouterChain(LLMRouter, BaseModel):
return re_matches.group(1), re_matches.group(2)
@property
def finish_action_name(self) -> str:
"""Name of the action of when to finish the chain."""
def finish_tool_name(self) -> str:
"""Name of the tool of when to finish the chain."""
return "Finish"
@property
@@ -52,6 +60,10 @@ class ReActRouterChain(LLMRouter, BaseModel):
"""Prefix to append the observation with."""
return f"Observation {self.i - 1}: "
@property
def _stop(self) -> List[str]:
return [f"\nObservation {self.i}: "]
@property
def router_prefix(self) -> str:
"""Prefix to append the router call with."""
@@ -95,10 +107,10 @@ class ReActChain(RoutingChain):
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
"""Initialize with the LLM and a docstore."""
router = ReActRouterChain(llm)
docstore_explorer = DocstoreExplorer(docstore)
tool_configs = [
ToolConfig(tool_name="Search", tool=docstore_explorer.search),
ToolConfig(tool_name="Lookup", tool=docstore_explorer.lookup),
tools = [
Tool(name="Search", func=docstore_explorer.search),
Tool(name="Lookup", func=docstore_explorer.lookup),
]
super().__init__(router=router, expert_configs=tool_configs, **kwargs)
router = ReActDocstoreRouter.from_llm_and_tools(llm, tools)
super().__init__(router=router, tools=tools, **kwargs)
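A hedged sketch of wiring up `ReActChain` as the constructor above expects: an LLM plus a docstore exposing search and lookup. The `Wikipedia` docstore import path is an assumption; it is not shown in this diff.

```python
# Sketch: ReActChain pairs an LLM with a docstore exposing search/lookup.
# The Wikipedia import path below is assumed, not part of this diff.
from langchain.llms import OpenAI
from langchain.docstore.wikipedia import Wikipedia
from langchain.routing_chains import ReActChain

llm = OpenAI(temperature=0)
react = ReActChain(llm=llm, docstore=Wikipedia(), verbose=True)
# react.run("Author David Chanoff has collaborated with a U.S. Navy admiral "
#           "who served as the ambassador to the United Kingdom under which President?")
```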

@@ -1,10 +1,12 @@
"""Chain that takes in an input and produces an action and action input."""
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple
from pydantic import BaseModel
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from langchain.routing_chains.tools import Tool
class RouterOutput(NamedTuple):
@@ -63,6 +65,15 @@ class LLMRouter(Router, BaseModel, ABC):
"""Fix the text."""
raise ValueError("fix_text not implemented for this router.")
@property
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
@classmethod
@abstractmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "Router":
"""Construct a router from an LLM and tools."""
def route(self, text: str) -> RouterOutput:
"""Given input, decided how to route it.
@@ -73,12 +84,12 @@ class LLMRouter(Router, BaseModel, ABC):
RouterOutput specifying what tool to use.
"""
input_key = self.llm_chain.input_keys[0]
inputs = {input_key: text, "stop": [self.observation_prefix]}
inputs = {input_key: text, "stop": self._stop}
full_output = self.llm_chain.predict(**inputs)
parsed_output = self._extract_tool_and_input(full_output)
while parsed_output is None:
full_output = self._fix_text(full_output)
inputs = {input_key: text + full_output, "stop": [self.observation_prefix]}
inputs = {input_key: text + full_output, "stop": self._stop}
output = self.llm_chain.predict(**inputs)
full_output += output
parsed_output = self._extract_tool_and_input(full_output)

@@ -1,18 +1,12 @@
"""Router-Expert framework."""
from typing import Callable, Dict, List, NamedTuple
from typing import Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.input import ChainedInput, get_color_mapping
from langchain.routing_chains.router import Router
class ToolConfig(NamedTuple):
"""Configuration for tools."""
tool_name: str
tool: Callable[[str], str]
from langchain.routing_chains.tools import Tool
class RoutingChain(Chain, BaseModel):
@@ -20,8 +14,8 @@ class RoutingChain(Chain, BaseModel):
router: Router
"""Router to use."""
tool_configs: List[ToolConfig]
"""Tool configs this chain has access to."""
tools: List[Tool]
"""Tools this chain has access to."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
@@ -49,7 +43,7 @@ class RoutingChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tc.tool_name: tc.tool for tc in self.tool_configs}
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# Construct the initial string to pass into the router. This is made up
# of the user input, the special starter string, and then the router prefix.
# The starter string is a special string that may be used by a router to
@@ -64,7 +58,7 @@ class RoutingChain(Chain, BaseModel):
chained_input = ChainedInput(starter_string, verbose=self.verbose)
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[c.tool_name for c in self.tool_configs], excluded_colors=["green"]
[tool.name for tool in self.tools], excluded_colors=["green"]
)
# We now enter the router loop (until it returns something).
while True:

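For reference, a minimal sketch of building a `RoutingChain` directly from a router and tools, mirroring what `load_routing_chain` returns above; the stub tool is hypothetical.

```python
# Sketch: constructing the chain by hand instead of via load_routing_chain.
from langchain.llms import OpenAI
from langchain.routing_chains.mrkl.base import ZeroShotRouter
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.tools import Tool

tools = [
    Tool(
        name="Search",
        func=lambda q: "stubbed search result",  # stand-in for a real search tool
        description="useful for when you need to answer questions about current events",
    )
]
router = ZeroShotRouter.from_llm_and_tools(OpenAI(temperature=0), tools)
chain = RoutingChain(router=router, tools=tools, verbose=True)
```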
@@ -1,21 +1,33 @@
"""Chain that does self ask with search."""
from typing import Any, Tuple
from typing import Any, List, Tuple
from langchain.chains.llm import LLMChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.base import LLM
from langchain.routing_chains.router import LLMRouter
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
from langchain.routing_chains.routing_chain import RoutingChain
from langchain.routing_chains.self_ask_with_search.prompt import PROMPT
from langchain.routing_chains.tools import Tool
class SelfAskWithSearchRouter(LLMRouter):
"""Router for the self-ask-with-search paper."""
def __init__(self, llm: LLM, **kwargs: Any):
"""Initialize with an LLM."""
@classmethod
def from_llm_and_tools(
cls, llm: LLM, tools: List[Tool]
) -> "SelfAskWithSearchRouter":
"""Construct a router from an LLM and tools."""
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
if tool_names != {"Intermediate Answer"}:
raise ValueError(
f"Tool name should be Intermediate Answer, got {tool_names}"
)
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
super().__init__(llm_chain=llm_chain, **kwargs)
return cls(llm_chain=llm_chain, tools=tools)
def _extract_tool_and_input(self, text: str) -> Tuple[str, str]:
followup = "Follow up:"
@@ -71,8 +83,6 @@ class SelfAskWithSearchChain(RoutingChain):
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
"""Initialize with just an LLM and a search chain."""
intermediate = "\nIntermediate answer:"
router = SelfAskWithSearchRouter(llm, stops=[intermediate])
search_tool = ToolConfig(tool_name="Intermediate Answer", tool=search_chain.run)
expert_configs = [search_tool]
super().__init__(router=router, expert_configs=expert_configs, **kwargs)
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
router = SelfAskWithSearchRouter.from_llm_and_tools(llm, [search_tool])
super().__init__(router=router, tools=[search_tool], **kwargs)
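A usage sketch matching the `SelfAskWithSearchChain` constructor above; running it requires a SerpAPI key, and the question is only illustrative.

```python
# Usage sketch for the constructor above.
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms import OpenAI
from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchChain

llm = OpenAI(temperature=0)
search_chain = SerpAPIChain()
self_ask = SelfAskWithSearchChain(llm=llm, search_chain=search_chain, verbose=True)
# self_ask.run("What is the hometown of the reigning men's U.S. Open champion?")
```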

@@ -0,0 +1,10 @@
"""Interface for tools."""
from typing import Callable, NamedTuple, Optional
class Tool(NamedTuple):
"""Interface for tools."""
name: str
func: Callable[[str], str]
description: Optional[str] = None

@@ -3,12 +3,9 @@
import pytest
from langchain.prompts import PromptTemplate
from langchain.routing_chains.mrkl.base import (
ChainConfig,
MRKLRouterChain,
get_action_and_input,
)
from langchain.routing_chains.mrkl.base import ZeroShotRouter, get_action_and_input
from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE
from langchain.routing_chains.tools import Tool
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -56,14 +53,10 @@ def test_bad_action_line() -> None:
def test_from_chains() -> None:
"""Test initializing from chains."""
chain_configs = [
ChainConfig(
action_name="foo", action=lambda x: "foo", action_description="foobar1"
),
ChainConfig(
action_name="bar", action=lambda x: "bar", action_description="foobar2"
),
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
]
router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
router_chain = ZeroShotRouter.from_llm_and_tools(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar"
expected_template = BASE_TEMPLATE.format(

@@ -8,7 +8,7 @@ from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate
from langchain.routing_chains.react.base import ReActChain, ReActRouterChain
from langchain.routing_chains.react.base import ReActChain, ReActDocstoreRouter
_PAGE_CONTENT = """This is a page about LangChain.
@@ -52,7 +52,7 @@ def test_predict_until_observation_normal() -> None:
"""Test predict_until_observation when observation is made normally."""
outputs = ["foo\nAction 1: search[foo]"]
fake_llm = FakeListLLM(outputs)
router_chain = ReActRouterChain(llm=fake_llm)
router_chain = ReActDocstoreRouter(llm=fake_llm)
output = router_chain.route("")
assert output.log == outputs[0]
assert output.tool == "search"
@@ -63,7 +63,7 @@ def test_predict_until_observation_repeat() -> None:
"""Test when no action is generated initially."""
outputs = ["foo", " search[foo]"]
fake_llm = FakeListLLM(outputs)
router_chain = ReActRouterChain(llm=fake_llm)
router_chain = ReActDocstoreRouter(llm=fake_llm)
output = router_chain.route("")
assert output.log == "foo\nAction 1: search[foo]"
assert output.tool == "search"
