reorg smart chains

harrison/reorg_smart_chains
Harrison Chase 2 years ago
parent 2a84d3d5ca
commit 68eaf4e5ee

@ -20,9 +20,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
@ -33,8 +30,6 @@
"\u001b[32;1m\u001b[1;3mFollow up: Where is Carlos Alcaraz from?\u001b[0m\n",
"Intermediate answer: \u001b[36;1m\u001b[1;3mEl Palmar, Spain\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSo the final answer is: El Palmar, Spain\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},

@ -8,10 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
from langchain.chains import (
LLMChain,
LLMMathChain,
MRKLChain,
PythonChain,
ReActChain,
SelfAskWithSearchChain,
SerpAPIChain,
SQLDatabaseChain,
VectorDBQA,
@ -19,6 +16,7 @@ from langchain.chains import (
from langchain.docstore import Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.smart_chains import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.sql_database import SQLDatabase
from langchain.vectorstores import FAISS, ElasticVectorSearch

@ -1,10 +1,7 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.mrkl.base import MRKLChain
from langchain.chains.python import PythonChain
from langchain.chains.react.base import ReActChain
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.vector_db_qa.base import VectorDBQA
@ -13,10 +10,7 @@ __all__ = [
"LLMChain",
"LLMMathChain",
"PythonChain",
"SelfAskWithSearchChain",
"SerpAPIChain",
"ReActChain",
"SQLDatabaseChain",
"MRKLChain",
"VectorDBQA",
]

@ -0,0 +1,6 @@
"""Smart chains."""
from langchain.smart_chains.mrkl.base import MRKLChain
from langchain.smart_chains.react.base import ReActChain
from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain
__all__ = ["MRKLChain", "SelfAskWithSearchChain", "ReActChain"]

@ -1,16 +1,12 @@
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.chains.router import LLMRouterChain
from langchain.input import ChainedInput, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
from langchain.prompts import PromptTemplate
from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE
from langchain.smart_chains.router import LLMRouterChain
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
FINAL_ANSWER_ACTION = "Final Answer: "
@ -79,33 +75,20 @@ class MRKLRouterChain(LLMRouterChain):
return get_action_and_input(text)
class MRKLChain(Chain, BaseModel):
class MRKLChain(RouterExpertChain):
"""Chain that implements the MRKL system.
Example:
.. code-block:: python
from langchain import OpenAI, Prompt, MRKLChain
from langchain import OpenAI, MRKLChain
from langchain.chains.mrkl.base import ChainConfig
llm = OpenAI(temperature=0)
prompt = PromptTemplate(...)
action_to_chain_map = {...}
mrkl = MRKLChain(
llm=llm,
prompt=prompt,
action_to_chain_map=action_to_chain_map
)
chains = [...]
mrkl = MRKLChain.from_chains(llm=llm, prompt=prompt)
"""
llm: LLM
"""LLM wrapper to use as router."""
chain_configs: List[ChainConfig]
"""Chain configs this chain has access to."""
action_to_chain_map: Dict[str, Callable]
"""Mapping from action name to chain to execute."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
@classmethod
def from_chains(
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
@ -145,47 +128,8 @@ class MRKLChain(Chain, BaseModel):
]
mrkl = MRKLChain.from_chains(llm, chains)
"""
action_to_chain_map = {chain.action_name: chain.action for chain in chains}
return cls(
llm=llm,
chain_configs=chains,
action_to_chain_map=action_to_chain_map,
**kwargs,
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
router_chain = MRKLRouterChain(self.llm, self.chain_configs)
question = inputs[self.input_key]
router_chain = MRKLRouterChain(llm, chains)
expert_configs = [
ExpertConfig(expert_name=c.action_name, expert=c.action)
for c in self.chain_configs
ExpertConfig(expert_name=c.action_name, expert=c.action) for c in chains
]
chain = RouterExpertChain(
router_chain=router_chain,
expert_configs=expert_configs,
verbose=self.verbose
)
output = chain.run(question)
return {self.output_key: output}
return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs)

@ -1,18 +1,16 @@
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
import re
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional, Tuple
from pydantic import BaseModel, Extra
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.react.prompt import PROMPT
from langchain.chains.router import LLMRouterChain
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
from langchain.smart_chains.react.prompt import PROMPT
from langchain.smart_chains.router import LLMRouterChain
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
class ReActRouterChain(LLMRouterChain, BaseModel):
@ -46,7 +44,7 @@ class ReActRouterChain(LLMRouterChain, BaseModel):
@property
def finish_action_name(self) -> str:
"""The action name of when to finish the chain."""
"""Name of the action of when to finish the chain."""
return "Finish"
@property
@ -61,12 +59,15 @@ class ReActRouterChain(LLMRouterChain, BaseModel):
class DocstoreExplorer:
"""Class to assist with exploration of a document store."""
def __init__(self, docstore: Docstore):
self.docstore=docstore
self.document = None
"""Initialize with a docstore, and set initial document to None."""
self.docstore = docstore
self.document: Optional[Document] = None
def search(self, term: str):
def search(self, term: str) -> str:
"""Search for a term in the docstore, and if found save."""
result = self.docstore.search(term)
if isinstance(result, Document):
self.document = result
@ -75,13 +76,14 @@ class DocstoreExplorer:
self.document = None
return result
def lookup(self, term: str):
def lookup(self, term: str) -> str:
"""Lookup a term in document (if saved)."""
if self.document is None:
raise ValueError("Cannot lookup without a successful search first")
return self.document.lookup(term)
class ReActChain(Chain, BaseModel):
class ReActChain(RouterExpertChain):
"""Chain that implements the ReAct paper.
Example:
@ -91,47 +93,14 @@ class ReActChain(Chain, BaseModel):
react = ReAct(llm=OpenAI())
"""
llm: LLM
"""LLM wrapper to use."""
docstore: Docstore
"""Docstore to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
question = inputs[self.input_key]
router_chain = ReActRouterChain(self.llm)
docstore_explorer = DocstoreExplorer(self.docstore)
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
"""Initialize with the LLM and a docstore."""
router_chain = ReActRouterChain(llm)
docstore_explorer = DocstoreExplorer(docstore)
expert_configs = [
ExpertConfig(expert_name="Search", expert=docstore_explorer.search),
ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup)
ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup),
]
chain = RouterExpertChain(
router_chain=router_chain,
expert_configs=expert_configs,
verbose=self.verbose
super().__init__(
router_chain=router_chain, expert_configs=expert_configs, **kwargs
)
output = chain.run(question)
return {self.output_key: output}

@ -48,7 +48,7 @@ class RouterChain(Chain, BaseModel, ABC):
@property
def finish_action_name(self) -> str:
"""The action name of when to finish the chain."""
"""Name of the action of when to finish the chain."""
return "Final Answer"
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:

@ -1,21 +1,15 @@
"""Router-Expert framework."""
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, NamedTuple
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.chains.router import LLMRouterChain
from langchain.input import ChainedInput, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.chains.router import RouterChain
from langchain.smart_chains.router import RouterChain
class ExpertConfig(NamedTuple):
"""Configuration for experts."""
expert_name: str
expert: Callable[[str], str]
@ -57,8 +51,14 @@ class RouterExpertChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
action_to_chain_map = {e.expert_name: e.expert for e in self.expert_configs}
starter_string = (
inputs[self.input_key]
+ self.starter_string
+ self.router_chain.router_prefix
)
chained_input = ChainedInput(
f"{inputs[self.input_key]}{self.starter_string}{self.router_chain.router_prefix}", verbose=self.verbose
starter_string,
verbose=self.verbose,
)
color_mapping = get_color_mapping(
[c.expert_name for c in self.expert_configs], excluded_colors=["green"]
@ -74,4 +74,4 @@ class RouterExpertChain(Chain, BaseModel):
ca = chain(action_input)
chained_input.add(f"\n{self.router_chain.observation_prefix}")
chained_input.add(ca, color=color_mapping[action])
chained_input.add(f"\n{self.router_chain.router_prefix}")
chained_input.add(f"\n{self.router_chain.router_prefix}")

@ -1,16 +1,12 @@
"""Chain that does self ask with search."""
from typing import Any, Dict, List, Tuple
from typing import Any, Tuple
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.router import LLMRouterChain
from langchain.chains.self_ask_with_search.prompt import PROMPT
from langchain.chains.serpapi import SerpAPIChain
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
from langchain.smart_chains.router import LLMRouterChain
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
from langchain.smart_chains.self_ask_with_search.prompt import PROMPT
class SelfAskWithSearchRouter(LLMRouterChain):
@ -32,7 +28,7 @@ class SelfAskWithSearchRouter(LLMRouterChain):
finish_string = "So the final answer is: "
if finish_string not in last_line:
raise ValueError("We should probably never get here")
return "Final Answer", text[len(finish_string):]
return "Final Answer", text[len(finish_string) :]
if ":" not in last_line:
after_colon = last_line
@ -57,7 +53,7 @@ class SelfAskWithSearchRouter(LLMRouterChain):
return ""
class SelfAskWithSearchChain(Chain, BaseModel):
class SelfAskWithSearchChain(RouterExpertChain):
"""Chain that does self ask with search.
Example:
@ -68,39 +64,16 @@ class SelfAskWithSearchChain(Chain, BaseModel):
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
"""
llm: LLM
"""LLM wrapper to use."""
search_chain: SerpAPIChain
"""Search chain to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
"""Initialize with just an LLM and a search chain."""
intermediate = "\nIntermediate answer:"
router = SelfAskWithSearchRouter(self.llm, stops=[intermediate])
expert_configs = [ExpertConfig(expert_name="Intermediate Answer", expert=self.search_chain.run)]
chain = RouterExpertChain(router_chain=router, expert_configs=expert_configs, verbose=self.verbose, starter_string="\nAre follow up questions needed here:")
output = chain.run(inputs[self.input_key])
return {self.output_key: output}
router = SelfAskWithSearchRouter(llm, stops=[intermediate])
expert_configs = [
ExpertConfig(expert_name="Intermediate Answer", expert=search_chain.run)
]
super().__init__(
router_chain=router,
expert_configs=expert_configs,
starter_string="\nAre follow up questions needed here:",
**kwargs
)

@ -1,8 +1,8 @@
"""Integration test for self ask with search."""
from langchain.chains.react.base import ReActChain
from langchain.docstore.wikipedia import Wikipedia
from langchain.llms.openai import OpenAI
from langchain.smart_chains.react.base import ReActChain
def test_react() -> None:

@ -1,7 +1,7 @@
"""Integration test for self ask with search."""
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.openai import OpenAI
from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain
def test_self_ask_with_search() -> None:

@ -0,0 +1 @@
"""Test smart chain functionality."""

@ -2,13 +2,13 @@
import pytest
from langchain.chains.mrkl.base import (
from langchain.prompts import PromptTemplate
from langchain.smart_chains.mrkl.base import (
ChainConfig,
MRKLRouterChain,
get_action_and_input,
)
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.prompts import PromptTemplate
from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE
from tests.unit_tests.llms.fake_llm import FakeLLM

@ -4,11 +4,11 @@ from typing import Any, List, Mapping, Optional, Union
import pytest
from langchain.chains.react.base import ReActChain, ReActRouterChain
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate
from langchain.smart_chains.react.base import ReActChain, ReActRouterChain
_PAGE_CONTENT = """This is a page about LangChain.
Loading…
Cancel
Save