diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py index 72c5b845..4ea1f23e 100644 --- a/langchain/agents/chat/base.py +++ b/langchain/agents/chat/base.py @@ -4,7 +4,12 @@ from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser -from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.agents.chat.prompt import ( + FORMAT_INSTRUCTIONS, + HUMAN_MESSAGE, + SYSTEM_MESSAGE_PREFIX, + SYSTEM_MESSAGE_SUFFIX, +) from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager @@ -64,18 +69,26 @@ class ChatAgent(Agent): def create_prompt( cls, tools: Sequence[BaseTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, + system_message_prefix: str = SYSTEM_MESSAGE_PREFIX, + system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX, + human_message: str = HUMAN_MESSAGE, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: Optional[List[str]] = None, ) -> BasePromptTemplate: tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) + template = "\n\n".join( + [ + system_message_prefix, + tool_strings, + format_instructions, + system_message_suffix, + ] + ) messages = [ SystemMessagePromptTemplate.from_template(template), - HumanMessagePromptTemplate.from_template("{input}\n\n{agent_scratchpad}"), + HumanMessagePromptTemplate.from_template(human_message), ] if input_variables is None: input_variables = ["input", "agent_scratchpad"] @@ -88,8 +101,9 @@ class ChatAgent(Agent): tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, - prefix: str = PREFIX, - suffix: str = SUFFIX, + system_message_prefix: str = SYSTEM_MESSAGE_PREFIX, + system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX, + human_message: str = HUMAN_MESSAGE, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: Optional[List[str]] = None, **kwargs: Any, @@ -98,8 +112,9 @@ class ChatAgent(Agent): cls._validate_tools(tools) prompt = cls.create_prompt( tools, - prefix=prefix, - suffix=suffix, + system_message_prefix=system_message_prefix, + system_message_suffix=system_message_suffix, + human_message=human_message, format_instructions=format_instructions, input_variables=input_variables, ) diff --git a/langchain/agents/chat/prompt.py b/langchain/agents/chat/prompt.py index 3625e202..4343739b 100644 --- a/langchain/agents/chat/prompt.py +++ b/langchain/agents/chat/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" +SYSTEM_MESSAGE_PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob. Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here). @@ -26,4 +26,5 @@ Observation: the result of the action ... (this Thought/Action/Observation can repeat N times) Thought: I now know the final answer Final Answer: the final answer to the original input question""" -SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding.""" +SYSTEM_MESSAGE_SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding.""" +HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"