# NOTE: the text originally here was residue from the code-hosting web page
# (a topic-selection hint and the file statistics "156 lines, 5.3 KiB"); it is
# not part of this module and is kept only as this comment.

"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
from typing import Any, List, Optional, Sequence

from pydantic import Field

from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.react.output_parser import ReActOutputParser
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
class ReActDocstoreAgent(Agent):
    """Agent for the ReAct chain.

    Requires exactly two tools, named "Search" and "Lookup" (enforced by
    ``_validate_tools``), and defaults to ``ReActOutputParser`` for parsing.
    """

    output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)

    # NOTE(review): the ``@property`` decorators below were restored on the
    # assumption that the ``Agent`` base class exposes ``_agent_type``,
    # ``observation_prefix``, ``_stop`` and ``llm_prefix`` as properties --
    # confirm against the base class.

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Return a fresh ``ReActOutputParser``; ``kwargs`` are not used."""
        return ReActOutputParser()

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        return AgentType.REACT_DOCSTORE

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt.

        ``tools`` is accepted for interface compatibility but not consulted.
        """
        return WIKI_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` is exactly one "Lookup" and one "Search" tool.

        Raises:
            ValueError: If the number of tools is not two, or if the set of
                tool names is not ``{"Lookup", "Search"}``.
        """
        if len(tools) != 2:
            raise ValueError(f"Exactly two tools must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Lookup", "Search"}:
            raise ValueError(
                f"Tool names should be Lookup and Search, got {tool_names}"
            )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def _stop(self) -> List[str]:
        """Return the stop sequences: generation halts before an observation."""
        return ["\nObservation:"]

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return "Thought:"
class DocstoreExplorer:
    """Class to assist with exploration of a document store.

    Remembers the document returned by the last successful ``search`` so that
    ``lookup`` can step through the paragraphs of that document which contain
    a given term.
    """

    def __init__(self, docstore: Docstore):
        """Initialize with a docstore, and set initial document to None."""
        self.docstore = docstore
        self.document: Optional[Document] = None
        # Lower-cased term of the most recent lookup, and the 0-based index of
        # the match that was last returned for it.
        self.lookup_str = ""
        self.lookup_index = 0

    def search(self, term: str) -> str:
        """Search for a term in the docstore, and if found save.

        Args:
            term: The term passed to ``self.docstore.search``.

        Returns:
            The first paragraph of the document when the docstore returns a
            ``Document``; otherwise whatever the docstore returned, unchanged
            (presumably a not-found message -- confirm against the docstore).
        """
        result = self.docstore.search(term)
        if isinstance(result, Document):
            self.document = result
            return self._summary
        # Not a document: forget any previously saved one so that a later
        # lookup cannot run against stale content.
        self.document = None
        return result

    def lookup(self, term: str) -> str:
        """Lookup a term in document (if saved).

        Matching is case-insensitive. Repeating the same term returns the next
        matching paragraph; a different term restarts from the first match.

        Args:
            term: The text to look for in the saved document's paragraphs.

        Returns:
            ``"(Result i/n) <paragraph>"`` for the current match,
            ``"No Results"`` when nothing matches, or ``"No More Results"``
            once every match has been returned.

        Raises:
            ValueError: If no document has been saved by a successful search.
        """
        if self.document is None:
            raise ValueError("Cannot lookup without a successful search first")
        if term.lower() != self.lookup_str:
            # New term: start again from the first match.
            self.lookup_str = term.lower()
            self.lookup_index = 0
        else:
            # Same term as last time: advance to the next match.
            self.lookup_index += 1
        lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
        if not lookups:
            return "No Results"
        if self.lookup_index >= len(lookups):
            return "No More Results"
        result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
        return f"{result_prefix} {lookups[self.lookup_index]}"

    @property
    def _summary(self) -> str:
        """Return the first paragraph of the saved document."""
        return self._paragraphs[0]

    @property
    def _paragraphs(self) -> List[str]:
        """Return the saved document's content split on blank lines.

        Raises:
            ValueError: If no document is currently saved.
        """
        if self.document is None:
            raise ValueError("Cannot get paragraphs without a document")
        return self.document.page_content.split("\n\n")
class ReActTextWorldAgent(ReActDocstoreAgent):
    """Agent for the ReAct TextWorld chain.

    Differs from ``ReActDocstoreAgent`` in its default prompt and in requiring
    a single tool named "Play".
    """

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt.

        ``tools`` is accepted for interface compatibility but not consulted.
        """
        return TEXTWORLD_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` holds exactly one tool, named "Play".

        Raises:
            ValueError: If the number of tools is not one, or if the tool is
                not named "Play".
        """
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Play"}:
            raise ValueError(f"Tool name should be Play, got {tool_names}")
class ReActChain(AgentExecutor):
    """Chain that implements the ReAct paper.

    Example:
        .. code-block:: python

            from langchain import ReActChain, OpenAI

            # ``docstore`` is any Docstore instance to search and look up in.
            react = ReActChain(llm=OpenAI(), docstore=docstore)
    """

    def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
        """Initialize with the LLM and a docstore.

        Args:
            llm: Language model handed to ``ReActDocstoreAgent``.
            docstore: Document store wrapped in a ``DocstoreExplorer``, whose
                ``search`` and ``lookup`` methods back the two tools.
            **kwargs: Forwarded unchanged to the ``AgentExecutor`` initializer.
        """
        docstore_explorer = DocstoreExplorer(docstore)
        # The names must be exactly "Search" and "Lookup":
        # ReActDocstoreAgent._validate_tools rejects anything else.
        tools = [
            Tool(
                name="Search",
                func=docstore_explorer.search,
                description="Search for a term in the docstore.",
            ),
            Tool(
                name="Lookup",
                func=docstore_explorer.lookup,
                description="Lookup a term in the docstore.",
            ),
        ]
        agent = ReActDocstoreAgent.from_llm_and_tools(llm, tools)
        super().__init__(agent=agent, tools=tools, **kwargs)