forked from Archives/langchain
Compare commits
6 Commits
main
...
harrison/r
Author | SHA1 | Date |
---|---|---|
Harrison Chase | 8c8eb47765 | 2 years ago |
Harrison Chase | 68eaf4e5ee | 2 years ago |
Harrison Chase | 2a84d3d5ca | 2 years ago |
Harrison Chase | 45ce74d0bc | 2 years ago |
Harrison Chase | 2a2d3323c9 | 2 years ago |
Harrison Chase | 6f55fa8ba7 | 2 years ago |
@ -1,107 +0,0 @@
|
|||||||
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.chains.react.prompt import PROMPT
|
|
||||||
from langchain.docstore.base import Docstore
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from langchain.input import ChainedInput
|
|
||||||
from langchain.llms.base import LLM
|
|
||||||
|
|
||||||
|
|
||||||
def predict_until_observation(
|
|
||||||
llm_chain: LLMChain, prompt: str, i: int
|
|
||||||
) -> Tuple[str, str, str]:
|
|
||||||
"""Generate text until an observation is needed."""
|
|
||||||
action_prefix = f"Action {i}: "
|
|
||||||
stop_seq = f"\nObservation {i}:"
|
|
||||||
ret_text = llm_chain.predict(input=prompt, stop=[stop_seq])
|
|
||||||
# Sometimes the LLM forgets to take an action, so we prompt it to.
|
|
||||||
while not ret_text.split("\n")[-1].startswith(action_prefix):
|
|
||||||
ret_text += f"\nAction {i}:"
|
|
||||||
new_text = llm_chain.predict(input=prompt + ret_text, stop=[stop_seq])
|
|
||||||
ret_text += new_text
|
|
||||||
# The action block should be the last line.
|
|
||||||
action_block = ret_text.split("\n")[-1]
|
|
||||||
action_str = action_block[len(action_prefix) :]
|
|
||||||
# Parse out the action and the directive.
|
|
||||||
re_matches = re.search(r"(.*?)\[(.*?)\]", action_str)
|
|
||||||
if re_matches is None:
|
|
||||||
raise ValueError(f"Could not parse action directive: {action_str}")
|
|
||||||
return ret_text, re_matches.group(1), re_matches.group(2)
|
|
||||||
|
|
||||||
|
|
||||||
class ReActChain(Chain, BaseModel):
    """Chain that implements the ReAct paper.

    Example:
        .. code-block:: python

            from langchain import ReActChain, OpenAI

            react = ReActChain(llm=OpenAI())
    """

    llm: LLM
    """LLM wrapper to use."""
    docstore: Docstore
    """Docstore to use."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the Thought/Action/Observation loop until a Finish action.

        Grows the prompt one step at a time: the LLM output (green) is
        followed by the observation from the docstore (yellow), until the
        model emits ``Finish[answer]``.
        """
        question = inputs[self.input_key]
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose)
        i = 1
        # Most recent Document found by a Search; Lookup is scoped to it.
        document = None
        while True:
            ret_text, action, directive = predict_until_observation(
                llm_chain, chained_input.input, i
            )
            chained_input.add(ret_text, color="green")
            if action == "Search":
                # A Document result means a page was found; a plain string
                # is the docstore's "not found" message.
                result = self.docstore.search(directive)
                if isinstance(result, Document):
                    document = result
                    observation = document.summary
                else:
                    document = None
                    observation = result
            elif action == "Lookup":
                if document is None:
                    raise ValueError("Cannot lookup without a successful search first")
                observation = document.lookup(directive)
            elif action == "Finish":
                return {self.output_key: directive}
            else:
                raise ValueError(f"Got unknown action directive: {action}")
            # Feed the observation back and prime the next thought.
            chained_input.add(f"\nObservation {i}: ")
            chained_input.add(observation, color="yellow")
            chained_input.add(f"\nThought {i + 1}:")
            i += 1
|
|
@ -1,149 +0,0 @@
|
|||||||
"""Chain that does self ask with search."""
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.chains.self_ask_with_search.prompt import PROMPT
|
|
||||||
from langchain.chains.serpapi import SerpAPIChain
|
|
||||||
from langchain.input import ChainedInput
|
|
||||||
from langchain.llms.base import LLM
|
|
||||||
|
|
||||||
|
|
||||||
def extract_answer(generated: str) -> str:
    """Extract the final answer from generated text.

    Takes the text after the last colon of the output (the whole text if
    it has no colon), then strips a single leading space and a single
    trailing period.

    Args:
        generated: Raw text produced by the LLM.

    Returns:
        The cleaned answer string (may be empty).
    """
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if ":" not in last_line:
        after_colon = last_line
    else:
        # NOTE: intentionally splits the full text, not just the last line —
        # the answer follows the last colon anywhere in the output.
        after_colon = generated.split(":")[-1]

    # Slice comparisons (rather than indexing) avoid an IndexError when the
    # extracted answer is empty, e.g. for "" or text ending in ":".
    if after_colon[:1] == " ":
        after_colon = after_colon[1:]
    if after_colon[-1:] == ".":
        after_colon = after_colon[:-1]

    return after_colon
|
|
||||||
|
|
||||||
|
|
||||||
def extract_question(generated: str, followup: str) -> str:
    """Extract the follow-up question from generated text.

    Args:
        generated: Raw text produced by the LLM.
        followup: Marker string (e.g. ``"Follow up:"``) expected on the
            last line of *generated*.

    Returns:
        The text after the last colon, with a single leading space
        stripped. Prints a diagnostic when the marker is missing or the
        result does not end with a question mark.
    """
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if followup not in last_line:
        print("we probably should never get here..." + generated)

    if ":" not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(":")[-1]

    # Slice comparisons (rather than indexing) avoid an IndexError when the
    # extracted question is empty, e.g. for "" or text ending in ":".
    if after_colon[:1] == " ":
        after_colon = after_colon[1:]
    if after_colon[-1:] != "?":
        print("we probably should never get here..." + generated)

    return after_colon
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_line(generated: str) -> str:
    """Return the final line of *generated* (the whole string if it has no newline)."""
    # rsplit with maxsplit=1 yields the text after the last newline, and the
    # entire string when no newline is present.
    return generated.rsplit("\n", 1)[-1]
|
|
||||||
|
|
||||||
|
|
||||||
def greenify(_input: str) -> str:
    """Wrap *_input* in ANSI escape codes for a green background highlight."""
    return "".join(["\x1b[102m", _input, "\x1b[0m"])
|
|
||||||
|
|
||||||
|
|
||||||
def yellowfy(_input: str) -> str:
    """Wrap *_input* in ANSI escape codes for a yellow background highlight."""
    return "\x1b[106m%s\x1b[0m" % _input
|
|
||||||
|
|
||||||
|
|
||||||
class SelfAskWithSearchChain(Chain, BaseModel):
    """Chain that does self ask with search.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain

            search_chain = SerpAPIChain()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    llm: LLM
    """LLM wrapper to use."""
    search_chain: SerpAPIChain
    """Search chain to use."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the self-ask loop: alternate LLM follow-ups with searches.

        While the LLM keeps emitting "Follow up:" questions, each one is
        answered via the search chain and appended to the growing prompt;
        once no follow-up remains, the final answer is elicited.
        """
        chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
        chained_input.add("\nAre follow up questions needed here:")
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        intermediate = "\nIntermediate answer:"
        followup = "Follow up:"
        finalans = "\nSo the final answer is:"
        ret_text = llm_chain.predict(input=chained_input.input, stop=[intermediate])
        chained_input.add(ret_text, color="green")
        while followup in get_last_line(ret_text):
            question = extract_question(ret_text, followup)
            external_answer = self.search_chain.run(question)
            if external_answer is not None:
                chained_input.add(intermediate + " ")
                chained_input.add(external_answer + ".", color="yellow")
                # NOTE(review): the literal below duplicates `intermediate` —
                # presumably it should just be stop=[intermediate]; confirm.
                ret_text = llm_chain.predict(
                    input=chained_input.input, stop=["\nIntermediate answer:"]
                )
                chained_input.add(ret_text, color="green")
            else:
                # We only get here in the very rare case that Google returns no answer.
                chained_input.add(intermediate + " ")
                preds = llm_chain.predict(
                    input=chained_input.input, stop=["\n" + followup, finalans]
                )
                # NOTE(review): `ret_text` is not updated in this branch, so if
                # its last line still contains a follow-up the loop repeats
                # with the same question — looks like a potential infinite
                # loop; verify intended behavior.
                chained_input.add(preds, color="green")

        if finalans not in ret_text:
            # The model stopped without stating a final answer; prompt for it.
            chained_input.add(finalans)
            ret_text = llm_chain.predict(input=chained_input.input, stop=["\n"])
            chained_input.add(ret_text, color="green")

        return {self.output_key: ret_text}
|
|
@ -0,0 +1,6 @@
|
|||||||
|
"""Smart chains."""
|
||||||
|
from langchain.smart_chains.mrkl.base import MRKLChain
|
||||||
|
from langchain.smart_chains.react.base import ReActChain
|
||||||
|
from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
|
|
||||||
|
__all__ = ["MRKLChain", "SelfAskWithSearchChain", "ReActChain"]
|
@ -0,0 +1,106 @@
|
|||||||
|
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
||||||
|
import re
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.docstore.base import Docstore
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.smart_chains.react.prompt import PROMPT
|
||||||
|
from langchain.smart_chains.router import LLMRouterChain
|
||||||
|
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
|
||||||
|
|
||||||
|
|
||||||
|
class ReActRouterChain(LLMRouterChain, BaseModel):
    """Router for the ReAct chain."""

    # 1-based index of the current Thought/Action/Observation step.
    i: int = 1

    def __init__(self, llm: LLM, **kwargs: Any):
        """Initialize with the language model."""
        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
        stops = ["\nObservation 1:"]
        super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)

    def _fix_text(self, text: str) -> str:
        # The model stopped without emitting an action; append the marker so
        # the next completion is forced to produce one.
        return text + f"\nAction {self.i}:"

    def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse ``Action {i}: Tool[input]`` from the last line, if present."""
        action_prefix = f"Action {self.i}: "
        if not text.split("\n")[-1].startswith(action_prefix):
            # No action yet; the base class will retry via _fix_text.
            return None
        # Advance to the next step and update the stop sequence to match.
        self.i += 1
        self.stops = [f"\nObservation {self.i}:"]
        action_block = text.split("\n")[-1]

        action_str = action_block[len(action_prefix) :]
        # Parse out the action and the directive.
        re_matches = re.search(r"(.*?)\[(.*?)\]", action_str)
        if re_matches is None:
            raise ValueError(f"Could not parse action directive: {action_str}")
        return re_matches.group(1), re_matches.group(2)

    @property
    def finish_action_name(self) -> str:
        """Name of the action of when to finish the chain."""
        return "Finish"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        # i was already incremented on a successful parse, so the current
        # observation belongs to step i - 1.
        return f"Observation {self.i - 1}: "

    @property
    def router_prefix(self) -> str:
        """Prefix to append the router call with."""
        return f"Thought {self.i}:"
|
||||||
|
|
||||||
|
|
||||||
|
class DocstoreExplorer:
    """Helper for exploring a document store.

    Remembers the most recently found ``Document`` so that subsequent
    ``lookup`` calls are scoped to it.
    """

    def __init__(self, docstore: Docstore):
        """Store the docstore and start with no current document."""
        self.docstore = docstore
        self.document: Optional[Document] = None

    def search(self, term: str) -> str:
        """Search the docstore for ``term``; cache the document when found."""
        found = self.docstore.search(term)
        if not isinstance(found, Document):
            # A non-Document result means no page was located; clear state.
            self.document = None
            return found
        self.document = found
        return found.summary

    def lookup(self, term: str) -> str:
        """Look up ``term`` inside the most recently found document."""
        if self.document is None:
            raise ValueError("Cannot lookup without a successful search first")
        return self.document.lookup(term)
|
||||||
|
|
||||||
|
|
||||||
|
class ReActChain(RouterExpertChain):
    """Chain that implements the ReAct paper.

    Example:
        .. code-block:: python

            from langchain import ReActChain, OpenAI

            react = ReActChain(llm=OpenAI(), docstore=docstore)
    """

    def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
        """Initialize with the LLM and a docstore."""
        # Wire a ReAct router to Search/Lookup experts over the docstore.
        router = ReActRouterChain(llm)
        explorer = DocstoreExplorer(docstore)
        experts = [
            ExpertConfig(expert_name="Search", expert=explorer.search),
            ExpertConfig(expert_name="Lookup", expert=explorer.lookup),
        ]
        super().__init__(router_chain=router, expert_configs=experts, **kwargs)
|
@ -0,0 +1,90 @@
|
|||||||
|
"""Chain that takes in an input and produces an action and action input."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
|
||||||
|
class RouterChain(Chain, BaseModel, ABC):
    """Chain responsible for deciding the action to take.

    Subclasses implement ``get_action_and_input`` and the two prefix
    properties; the output triple is consumed by ``RouterExpertChain``.
    """

    input_key: str = "input_text"  #: :meta private:
    action_key: str = "action"  #: :meta private:
    action_input_key: str = "action_input"  #: :meta private:
    log_key: str = "log"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Will be the input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return three keys: the action, the action input, and the log.

        :meta private:
        """
        return [self.action_key, self.action_input_key, self.log_key]

    @abstractmethod
    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return action, action input, and log (in that order)."""

    @property
    @abstractmethod
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""

    @property
    @abstractmethod
    def router_prefix(self) -> str:
        """Prefix to append the router call with."""

    @property
    def finish_action_name(self) -> str:
        """Name of the action of when to finish the chain."""
        return "Final Answer"

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the subclass parser and fan the triple out into output keys."""
        action, action_input, log = self.get_action_and_input(inputs[self.input_key])
        return {
            self.action_key: action,
            self.action_input_key: action_input,
            self.log_key: log,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRouterChain(RouterChain, BaseModel, ABC):
    """RouterChain that drives its routing decisions with an LLM."""

    # Chain used to produce the raw router text.
    llm_chain: LLMChain
    # Stop sequences handed to the LLM on every call.
    stops: Optional[List[str]]

    @abstractmethod
    def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Extract action and action input from llm output."""

    def _fix_text(self, text: str) -> str:
        """Fix the text."""
        raise ValueError("fix_text not implemented for this router.")

    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return action, action input, and log (in that order)."""
        prompt_key = self.llm_chain.input_keys[0]

        def _predict(prompt: str) -> str:
            # self.stops is read at call time: subclasses may mutate it
            # between iterations (e.g. the ReAct router).
            return self.llm_chain.predict(**{prompt_key: prompt, "stop": self.stops})

        full_output = _predict(text)
        parsed = self._extract_action_and_input(full_output)
        # Keep patching the output and re-prompting until it parses.
        while parsed is None:
            full_output = self._fix_text(full_output)
            full_output += _predict(text + full_output)
            parsed = self._extract_action_and_input(full_output)
        action, action_input = parsed
        return action, action_input, full_output
|
@ -0,0 +1,77 @@
|
|||||||
|
"""Router-Expert framework."""
|
||||||
|
from typing import Callable, Dict, List, NamedTuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.input import ChainedInput, get_color_mapping
|
||||||
|
from langchain.smart_chains.router import RouterChain
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertConfig(NamedTuple):
    """Configuration for experts."""

    # Action name the router uses to select this expert.
    expert_name: str
    # Callable mapping the action input string to an observation string.
    expert: Callable[[str], str]
|
||||||
|
|
||||||
|
|
||||||
|
class RouterExpertChain(Chain, BaseModel):
    """Chain that implements the Router/Expert system.

    Repeatedly asks ``router_chain`` which expert to consult, appends the
    expert's output to the running prompt, and stops when the router emits
    its finish action.
    """

    router_chain: RouterChain
    """Router chain."""
    expert_configs: List[ExpertConfig]
    """Expert configs this chain has access to."""
    starter_string: str = "\n"
    """String to put after user input but before first router."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Drive the router/expert loop until the finish action is emitted."""
        # Map each expert name to its callable for dispatch.
        action_to_chain_map = {e.expert_name: e.expert for e in self.expert_configs}
        starter_string = (
            inputs[self.input_key]
            + self.starter_string
            + self.router_chain.router_prefix
        )
        chained_input = ChainedInput(
            starter_string,
            verbose=self.verbose,
        )
        # Green is reserved for router output; each expert gets its own color.
        color_mapping = get_color_mapping(
            [c.expert_name for c in self.expert_configs], excluded_colors=["green"]
        )
        while True:
            action, action_input, log = self.router_chain.get_action_and_input(
                chained_input.input
            )
            chained_input.add(log, color="green")
            if action == self.router_chain.finish_action_name:
                return {self.output_key: action_input}
            # Raises KeyError if the router names an unknown expert.
            chain = action_to_chain_map[action]
            ca = chain(action_input)
            # Feed the observation back and prime the next router call.
            chained_input.add(f"\n{self.router_chain.observation_prefix}")
            chained_input.add(ca, color=color_mapping[action])
            chained_input.add(f"\n{self.router_chain.router_prefix}")
|
@ -0,0 +1,79 @@
|
|||||||
|
"""Chain that does self ask with search."""
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.smart_chains.router import LLMRouterChain
|
||||||
|
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
|
||||||
|
from langchain.smart_chains.self_ask_with_search.prompt import PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAskWithSearchRouter(LLMRouterChain):
    """Router for the self-ask-with-search paper."""

    def __init__(self, llm: LLM, **kwargs: Any):
        """Initialize with an LLM."""
        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
        super().__init__(llm_chain=llm_chain, **kwargs)

    def _extract_action_and_input(self, text: str) -> Tuple[str, str]:
        """Parse LLM output into an (action, action input) pair.

        Returns ``("Final Answer", answer)`` when the final-answer marker
        appears on the last line, otherwise ``("Intermediate Answer",
        question)`` with the follow-up question to search for.

        Raises:
            ValueError: If the last line has neither a follow-up nor a
                final-answer marker.
        """
        followup = "Follow up:"
        if "\n" not in text:
            last_line = text
        else:
            last_line = text.split("\n")[-1]

        if followup not in last_line:
            finish_string = "So the final answer is: "
            if finish_string not in last_line:
                raise ValueError("We should probably never get here")
            # Slice from where the marker actually occurs on the last line.
            # (Slicing the full text at len(finish_string) only worked when
            # the output was a single line starting with the marker; on
            # multi-line output it returned garbage.)
            answer_start = last_line.find(finish_string) + len(finish_string)
            return "Final Answer", last_line[answer_start:]

        if ":" not in last_line:
            after_colon = last_line
        else:
            after_colon = text.split(":")[-1]

        # Slice comparisons avoid an IndexError on an empty extraction.
        if after_colon[:1] == " ":
            after_colon = after_colon[1:]
        if after_colon[-1:] != "?":
            print("we probably should never get here..." + text)

        return "Intermediate Answer", after_colon

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Intermediate answer: "

    @property
    def router_prefix(self) -> str:
        """Prefix to append the router call with."""
        return ""
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAskWithSearchChain(RouterExpertChain):
    """Chain that does self ask with search.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain

            search_chain = SerpAPIChain()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
        """Initialize with just an LLM and a search chain."""
        # The router stops each completion at the intermediate-answer marker,
        # at which point the search expert supplies the answer.
        stop_marker = "\nIntermediate answer:"
        ask_router = SelfAskWithSearchRouter(llm, stops=[stop_marker])
        configs = [
            ExpertConfig(expert_name="Intermediate Answer", expert=search_chain.run)
        ]
        super().__init__(
            router_chain=ask_router,
            expert_configs=configs,
            starter_string="\nAre follow up questions needed here:",
            **kwargs
        )
|
@ -0,0 +1 @@
|
|||||||
|
"""Test smart chain functionality."""
|
Loading…
Reference in New Issue