diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index d5376947..6c41e16b 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -2,7 +2,7 @@ from langchain.agents.agent import Agent from langchain.agents.loading import initialize_agent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent -from langchain.agents.react.base import ReActChain +from langchain.agents.react.base import ReActChain, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain from langchain.agents.tools import Tool @@ -14,4 +14,5 @@ __all__ = [ "Tool", "initialize_agent", "ZeroShotAgent", + "ReActTextWorldAgent", ] diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index a9f1d4ca..ca380e1a 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -5,7 +5,8 @@ from typing import Any, ClassVar, List, Optional, Tuple from pydantic import BaseModel from langchain.agents.agent import Agent -from langchain.agents.react.prompt import PROMPT +from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT +from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.tools import Tool from langchain.chains.llm import LLMChain from langchain.docstore.base import Docstore @@ -17,7 +18,7 @@ from langchain.prompts.base import BasePromptTemplate class ReActDocstoreAgent(Agent, BaseModel): """Agent for the ReAct chin.""" - prompt: ClassVar[BasePromptTemplate] = PROMPT + prompt: ClassVar[BasePromptTemplate] = WIKI_PROMPT i: int = 1 @@ -96,6 +97,22 @@ class DocstoreExplorer: return self.document.lookup(term) +class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): + """Agent for the ReAct TextWorld chain.""" + + prompt: ClassVar[BasePromptTemplate] = TEXTWORLD_PROMPT + + i: int = 1 + + @classmethod + def _validate_tools(cls, tools: List[Tool]) -> None: + if len(tools) != 1: + raise ValueError(f"Exactly one tool must be specified, but got {tools}") + tool_names = {tool.name for tool in tools} + if tool_names != {"Play"}: + raise ValueError(f"Tool name should be Play, got {tool_names}") + + class ReActChain(ReActDocstoreAgent): """Chain that implements the ReAct paper. @@ -113,5 +130,5 @@ class ReActChain(ReActDocstoreAgent): Tool(name="Search", func=docstore_explorer.search), Tool(name="Lookup", func=docstore_explorer.lookup), ] - llm_chain = LLMChain(llm=llm, prompt=PROMPT) + llm_chain = LLMChain(llm=llm, prompt=WIKI_PROMPT) super().__init__(llm_chain=llm_chain, tools=tools, **kwargs) diff --git a/langchain/agents/react/textworld_prompt.py b/langchain/agents/react/textworld_prompt.py new file mode 100644 index 00000000..d8bb31e8 --- /dev/null +++ b/langchain/agents/react/textworld_prompt.py @@ -0,0 +1,49 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +EXAMPLES = [ + """Setup: You are now playing a fast paced round of TextWorld! Here is your task for +today. First of all, you could, like, try to travel east. After that, take the +binder from the locker. With the binder, place the binder on the mantelpiece. +Alright, thanks! + +-= Vault =- +You've just walked into a vault. You begin to take stock of what's here. + +An open safe is here. What a letdown! The safe is empty! You make out a shelf. +But the thing hasn't got anything on it. What, you think everything in TextWorld +should have stuff on it? + +You don't like doors? Why not try going east, that entranceway is unguarded. + +Thought 1: I need to travel east +Action 1: Play[go east] +Observation 1: -= Office =- +You arrive in an office. An ordinary one. + +You can make out a locker. The locker contains a binder. You see a case. The +case is empty, what a horrible day! You lean against the wall, inadvertently +pressing a secret button. The wall opens up to reveal a mantelpiece. You wonder +idly who left that here. The mantelpiece is standard. The mantelpiece appears to +be empty. If you haven't noticed it already, there seems to be something there +by the wall, it's a table. Unfortunately, there isn't a thing on it. Hm. Oh well +There is an exit to the west. Don't worry, it is unguarded. + +Thought 2: I need to take the binder from the locker +Action 2: Play[take binder] +Observation 2: You take the binder from the locker. + +Thought 3: I need to place the binder on the mantelpiece +Action 3: Play[put binder on mantelpiece] + +Observation 3: You put the binder on the mantelpiece. +Your score has just gone up by one point. +*** The End *** +Thought 4: The End has occurred +Action 4: Finish[yes] + +""" +] +SUFFIX = """\n\nSetup: {input}""" + +TEXTWORLD_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input"]) diff --git a/langchain/agents/react/prompt.py b/langchain/agents/react/wiki_prompt.py similarity index 98% rename from langchain/agents/react/prompt.py rename to langchain/agents/react/wiki_prompt.py index 33fbada6..27f7565e 100644 --- a/langchain/agents/react/prompt.py +++ b/langchain/agents/react/wiki_prompt.py @@ -109,4 +109,4 @@ Action 3: Finish[yes]""", ] SUFFIX = """\n\nQuestion: {input}""" -PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input"]) +WIKI_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input"])