From 7642f2159cb850032ed7d5670903c924b8554a10 Mon Sep 17 00:00:00 2001
From: Richard He
Date: Thu, 18 May 2023 12:09:31 -0400
Subject: [PATCH] Add human message as input variable to chat agent prompt creation (#4542)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Add human message as input variable to chat agent prompt creation

This PR adds human message and system message inputs to the
`CHAT_ZERO_SHOT_REACT_DESCRIPTION` agent, similar to the
[conversational chat agent](https://github.com/hwchase17/langchain/blob/7bcf238a1acf40aef21a5a198cf0e62d76f93c15/langchain/agents/conversational_chat/base.py#L64-L71).

I ran into this while trying to use the `create_prompt` function with the
[BabyAGI agent with tools notebook](https://python.langchain.com/en/latest/use_cases/autonomous_agents/baby_agi_with_agent.html),
since BabyAGI uses a "task" input variable instead of "input". For the plain
zero-shot ReAct agent this is fine, because I can manually change the suffix to
"{input}\n\n{agent_scratchpad}" just as the notebook does, but the chat
zero-shot agent hard-codes its human message, which blocked me from using
BabyAGI with it. I tested this fix in my own project
[Chrome-GPT](https://github.com/richardyc/Chrome-GPT) and it works.
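For illustration, here is a minimal sketch of how the new `human_message` parameter could be used to build a chat zero-shot prompt around BabyAGI's "task" variable; the stub `Tool` is only a placeholder, and wiring the resulting prompt into BabyAGI is not shown:

```python
from langchain.agents.chat.base import ChatAgent
from langchain.tools import Tool

# Placeholder tool so the snippet is self-contained; swap in real tools.
search = Tool(
    name="Search",
    func=lambda q: "stub result",
    description="Useful for answering questions about current events.",
)

# BabyAGI feeds the agent a "task" variable rather than the default "input",
# so both the human message template and input_variables are overridden.
prompt = ChatAgent.create_prompt(
    tools=[search],
    human_message="{task}\n\n{agent_scratchpad}",
    input_variables=["task", "agent_scratchpad"],
)
```

The same keyword is also threaded through `ChatAgent.from_llm_and_tools`, so the custom template can be supplied when constructing the agent directly.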
## Request for review

Agents / Tools / Toolkits
- @vowelparrot
---
 langchain/agents/chat/base.py   | 33 ++++++++++++++++++++++++---------
 langchain/agents/chat/prompt.py |  5 +++--
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py
index 72c5b845..4ea1f23e 100644
--- a/langchain/agents/chat/base.py
+++ b/langchain/agents/chat/base.py
@@ -4,7 +4,12 @@ from pydantic import Field
 
 from langchain.agents.agent import Agent, AgentOutputParser
 from langchain.agents.chat.output_parser import ChatOutputParser
-from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
+from langchain.agents.chat.prompt import (
+    FORMAT_INSTRUCTIONS,
+    HUMAN_MESSAGE,
+    SYSTEM_MESSAGE_PREFIX,
+    SYSTEM_MESSAGE_SUFFIX,
+)
 from langchain.agents.utils import validate_tools_single_input
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
@@ -64,18 +69,26 @@ class ChatAgent(Agent):
     def create_prompt(
         cls,
         tools: Sequence[BaseTool],
-        prefix: str = PREFIX,
-        suffix: str = SUFFIX,
+        system_message_prefix: str = SYSTEM_MESSAGE_PREFIX,
+        system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
+        human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
         input_variables: Optional[List[str]] = None,
     ) -> BasePromptTemplate:
         tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
         tool_names = ", ".join([tool.name for tool in tools])
         format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        template = "\n\n".join(
+            [
+                system_message_prefix,
+                tool_strings,
+                format_instructions,
+                system_message_suffix,
+            ]
+        )
         messages = [
             SystemMessagePromptTemplate.from_template(template),
-            HumanMessagePromptTemplate.from_template("{input}\n\n{agent_scratchpad}"),
+            HumanMessagePromptTemplate.from_template(human_message),
         ]
         if input_variables is None:
             input_variables = ["input", "agent_scratchpad"]
@@ -88,8 +101,9 @@ class ChatAgent(Agent):
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         output_parser: Optional[AgentOutputParser] = None,
-        prefix: str = PREFIX,
-        suffix: str = SUFFIX,
+        system_message_prefix: str = SYSTEM_MESSAGE_PREFIX,
+        system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
+        human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
         input_variables: Optional[List[str]] = None,
         **kwargs: Any,
@@ -98,8 +112,9 @@ class ChatAgent(Agent):
         cls._validate_tools(tools)
         prompt = cls.create_prompt(
             tools,
-            prefix=prefix,
-            suffix=suffix,
+            system_message_prefix=system_message_prefix,
+            system_message_suffix=system_message_suffix,
+            human_message=human_message,
             format_instructions=format_instructions,
             input_variables=input_variables,
         )
diff --git a/langchain/agents/chat/prompt.py b/langchain/agents/chat/prompt.py
index 3625e202..4343739b 100644
--- a/langchain/agents/chat/prompt.py
+++ b/langchain/agents/chat/prompt.py
@@ -1,5 +1,5 @@
 # flake8: noqa
-PREFIX = """Answer the following questions as best you can. You have access to the following tools:"""
+SYSTEM_MESSAGE_PREFIX = """Answer the following questions as best you can. You have access to the following tools:"""
 FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob.
 Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).
 
@@ -26,4 +26,5 @@ Observation: the result of the action
 ... (this Thought/Action/Observation can repeat N times)
 Thought: I now know the final answer
 Final Answer: the final answer to the original input question"""
-SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding."""
+SYSTEM_MESSAGE_SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding."""
+HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"