From 7dec2d399b3e012136843d168f804e7c958bb4a7 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 20 Sep 2023 13:02:55 -0700 Subject: [PATCH] format intermediate steps (#10794) Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../agents/format_scratchpad/__init__.py | 0 .../langchain/agents/format_scratchpad/log.py | 16 ++++ .../format_scratchpad/log_to_messages.py | 19 +++++ .../format_scratchpad/openai_functions.py | 66 ++++++++++++++++ .../langchain/agents/format_scratchpad/xml.py | 15 ++++ .../agent_token_buffer_memory.py | 6 +- .../agents/openai_functions_agent/base.py | 76 ++----------------- .../openai_functions_multi_agent/base.py | 70 ++--------------- .../agents/format_scratchpad/__init__.py | 0 .../agents/format_scratchpad/test_log.py | 40 ++++++++++ .../format_scratchpad/test_log_to_messages.py | 49 ++++++++++++ .../test_openai_functions.py | 60 +++++++++++++++ .../agents/format_scratchpad/test_xml.py | 40 ++++++++++ 13 files changed, 319 insertions(+), 138 deletions(-) create mode 100644 libs/langchain/langchain/agents/format_scratchpad/__init__.py create mode 100644 libs/langchain/langchain/agents/format_scratchpad/log.py create mode 100644 libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py create mode 100644 libs/langchain/langchain/agents/format_scratchpad/openai_functions.py create mode 100644 libs/langchain/langchain/agents/format_scratchpad/xml.py create mode 100644 libs/langchain/tests/unit_tests/agents/format_scratchpad/__init__.py create mode 100644 libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py create mode 100644 libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py create mode 100644 libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py create mode 100644 libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py diff --git a/libs/langchain/langchain/agents/format_scratchpad/__init__.py b/libs/langchain/langchain/agents/format_scratchpad/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/langchain/agents/format_scratchpad/log.py b/libs/langchain/langchain/agents/format_scratchpad/log.py new file mode 100644 index 0000000000..810556b2c0 --- /dev/null +++ b/libs/langchain/langchain/agents/format_scratchpad/log.py @@ -0,0 +1,16 @@ +from typing import List, Tuple + +from langchain.schema.agent import AgentAction + + +def format_log_to_str( + intermediate_steps: List[Tuple[AgentAction, str]], + observation_prefix: str = "Observation: ", + llm_prefix: str = "Thought: ", +) -> str: + """Construct the scratchpad that lets the agent continue its thought process.""" + thoughts = "" + for action, observation in intermediate_steps: + thoughts += action.log + thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}" + return thoughts diff --git a/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py new file mode 100644 index 0000000000..c370d3a987 --- /dev/null +++ b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py @@ -0,0 +1,19 @@ +from typing import List, Tuple + +from langchain.schema.agent import AgentAction +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage + + +def format_log_to_messages( + intermediate_steps: List[Tuple[AgentAction, str]], + template_tool_response: str = "{observation}", +) -> List[BaseMessage]: + """Construct the scratchpad that lets the agent continue its thought process.""" + thoughts: List[BaseMessage] = [] + for action, observation in intermediate_steps: + thoughts.append(AIMessage(content=action.log)) + human_message = HumanMessage( + content=template_tool_response.format(observation=observation) + ) + thoughts.append(human_message) + return thoughts diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py new file mode 100644 index 0000000000..7294181c94 --- /dev/null +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py @@ -0,0 +1,66 @@ +import json +from typing import List, Sequence, Tuple + +from langchain.schema.agent import AgentAction, AgentActionMessageLog +from langchain.schema.messages import AIMessage, BaseMessage, FunctionMessage + + +def _convert_agent_action_to_messages( + agent_action: AgentAction, observation: str +) -> List[BaseMessage]: + """Convert an agent action to a message. + + This code is used to reconstruct the original AI message from the agent action. + + Args: + agent_action: Agent action to convert. + + Returns: + AIMessage that corresponds to the original tool invocation. + """ + if isinstance(agent_action, AgentActionMessageLog): + return list(agent_action.message_log) + [ + _create_function_message(agent_action, observation) + ] + else: + return [AIMessage(content=agent_action.log)] + + +def _create_function_message( + agent_action: AgentAction, observation: str +) -> FunctionMessage: + """Convert agent action and observation into a function message. + Args: + agent_action: the tool invocation request from the agent + observation: the result of the tool invocation + Returns: + FunctionMessage that corresponds to the original tool invocation + """ + if not isinstance(observation, str): + try: + content = json.dumps(observation, ensure_ascii=False) + except Exception: + content = str(observation) + else: + content = observation + return FunctionMessage( + name=agent_action.tool, + content=content, + ) + + +def format_to_openai_functions( + intermediate_steps: Sequence[Tuple[AgentAction, str]], +) -> List[BaseMessage]: + """Format intermediate steps. + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + Returns: + list of messages to send to the LLM for the next prediction + """ + messages = [] + + for agent_action, observation in intermediate_steps: + messages.extend(_convert_agent_action_to_messages(agent_action, observation)) + + return messages diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py new file mode 100644 index 0000000000..6d02cab08b --- /dev/null +++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py @@ -0,0 +1,15 @@ +from typing import List, Tuple + +from langchain.schema.agent import AgentAction + + +def format_xml( + intermediate_steps: List[Tuple[AgentAction, str]], +) -> str: + log = "" + for action, observation in intermediate_steps: + log += ( + f"{action.tool}{action.tool_input}" + f"{observation}" + ) + return log diff --git a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py index aa24454804..6aff31f385 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py @@ -1,7 +1,9 @@ """Memory used to save agent output AND intermediate steps.""" from typing import Any, Dict, List -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps +from langchain.agents.format_scratchpad.openai_functions import ( + format_to_openai_functions, +) from langchain.memory.chat_memory import BaseChatMemory from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage, get_buffer_string @@ -50,7 +52,7 @@ class AgentTokenBufferMemory(BaseChatMemory): """Save context from this conversation to buffer. Pruned.""" input_str, output_str = self._get_input_output(inputs, outputs) self.chat_memory.add_user_message(input_str) - steps = _format_intermediate_steps(outputs[self.intermediate_steps_key]) + steps = format_to_openai_functions(outputs[self.intermediate_steps_key]) for msg in steps: self.chat_memory.add_message(msg) self.chat_memory.add_ai_message(output_str) diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index c54ff33319..abac565206 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -1,8 +1,10 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" -import json from typing import Any, List, Optional, Sequence, Tuple, Union from langchain.agents import BaseSingleActionAgent +from langchain.agents.format_scratchpad.openai_functions import ( + format_to_openai_functions, +) from langchain.agents.output_parsers.openai_functions import ( OpenAIFunctionsAgentOutputParser, ) @@ -21,82 +23,14 @@ from langchain.schema import ( AgentFinish, BasePromptTemplate, ) -from langchain.schema.agent import AgentActionMessageLog from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import ( - AIMessage, BaseMessage, - FunctionMessage, SystemMessage, ) from langchain.tools import BaseTool from langchain.tools.convert_to_openai import format_tool_to_openai_function -# For backwards compatibility -_FunctionsAgentAction = AgentActionMessageLog - - -def _convert_agent_action_to_messages( - agent_action: AgentAction, observation: str -) -> List[BaseMessage]: - """Convert an agent action to a message. - - This code is used to reconstruct the original AI message from the agent action. - - Args: - agent_action: Agent action to convert. - - Returns: - AIMessage that corresponds to the original tool invocation. - """ - if isinstance(agent_action, _FunctionsAgentAction): - return list(agent_action.message_log) + [ - _create_function_message(agent_action, observation) - ] - else: - return [AIMessage(content=agent_action.log)] - - -def _create_function_message( - agent_action: AgentAction, observation: str -) -> FunctionMessage: - """Convert agent action and observation into a function message. - Args: - agent_action: the tool invocation request from the agent - observation: the result of the tool invocation - Returns: - FunctionMessage that corresponds to the original tool invocation - """ - if not isinstance(observation, str): - try: - content = json.dumps(observation, ensure_ascii=False) - except Exception: - content = str(observation) - else: - content = observation - return FunctionMessage( - name=agent_action.tool, - content=content, - ) - - -def _format_intermediate_steps( - intermediate_steps: List[Tuple[AgentAction, str]], -) -> List[BaseMessage]: - """Format intermediate steps. - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - Returns: - list of messages to send to the LLM for the next prediction - """ - messages = [] - - for intermediate_step in intermediate_steps: - agent_action, observation = intermediate_step - messages.extend(_convert_agent_action_to_messages(agent_action, observation)) - - return messages - class OpenAIFunctionsAgent(BaseSingleActionAgent): """An Agent driven by OpenAIs function powered API. @@ -159,7 +93,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): Returns: Action specifying what tool to use. """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) + agent_scratchpad = format_to_openai_functions(intermediate_steps) selected_inputs = { k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" } @@ -198,7 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): Returns: Action specifying what tool to use. """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) + agent_scratchpad = format_to_openai_functions(intermediate_steps) selected_inputs = { k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" } diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 5849cf9718..0fbafd591b 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -4,6 +4,9 @@ from json import JSONDecodeError from typing import Any, List, Optional, Sequence, Tuple, Union from langchain.agents import BaseMultiActionAgent +from langchain.agents.format_scratchpad.openai_functions import ( + format_to_openai_functions, +) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import ChatOpenAI @@ -25,7 +28,6 @@ from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import ( AIMessage, BaseMessage, - FunctionMessage, SystemMessage, ) from langchain.tools import BaseTool @@ -34,68 +36,6 @@ from langchain.tools import BaseTool _FunctionsAgentAction = AgentActionMessageLog -def _convert_agent_action_to_messages( - agent_action: AgentAction, observation: str -) -> List[BaseMessage]: - """Convert an agent action to a message. - - This code is used to reconstruct the original AI message from the agent action. - - Args: - agent_action: Agent action to convert. - - Returns: - AIMessage that corresponds to the original tool invocation. - """ - if isinstance(agent_action, _FunctionsAgentAction): - return list(agent_action.message_log) + [ - _create_function_message(agent_action, observation) - ] - else: - return [AIMessage(content=agent_action.log)] - - -def _create_function_message( - agent_action: AgentAction, observation: str -) -> FunctionMessage: - """Convert agent action and observation into a function message. - Args: - agent_action: the tool invocation request from the agent - observation: the result of the tool invocation - Returns: - FunctionMessage that corresponds to the original tool invocation - """ - if not isinstance(observation, str): - try: - content = json.dumps(observation, ensure_ascii=False) - except Exception: - content = str(observation) - else: - content = observation - return FunctionMessage( - name=agent_action.tool, - content=content, - ) - - -def _format_intermediate_steps( - intermediate_steps: List[Tuple[AgentAction, str]], -) -> List[BaseMessage]: - """Format intermediate steps. - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - Returns: - list of messages to send to the LLM for the next prediction - """ - messages = [] - - for intermediate_step in intermediate_steps: - agent_action, observation = intermediate_step - messages.extend(_convert_agent_action_to_messages(agent_action, observation)) - - return messages - - def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]: """Parse an AI message.""" if not isinstance(message, AIMessage): @@ -257,7 +197,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): Returns: Action specifying what tool to use. """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) + agent_scratchpad = format_to_openai_functions(intermediate_steps) selected_inputs = { k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" } @@ -286,7 +226,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): Returns: Action specifying what tool to use. """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) + agent_scratchpad = format_to_openai_functions(intermediate_steps) selected_inputs = { k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" } diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/__init__.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py new file mode 100644 index 0000000000..411b57695c --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py @@ -0,0 +1,40 @@ +from langchain.agents.format_scratchpad.log import format_log_to_str +from langchain.schema.agent import AgentAction + + +def test_single_agent_action_observation() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + ] + expected_result = "Log1\nObservation: Observation1\nThought: " + assert format_log_to_str(intermediate_steps) == expected_result + + +def test_multiple_agent_actions_observations() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), + (AgentAction(tool="Tool2", tool_input="input2", log="Log2"), "Observation2"), + (AgentAction(tool="Tool3", tool_input="input3", log="Log3"), "Observation3"), + ] + expected_result = """Log1\nObservation: Observation1\nThought: \ +Log2\nObservation: Observation2\nThought: Log3\nObservation: \ +Observation3\nThought: """ + assert format_log_to_str(intermediate_steps) == expected_result + + +def test_custom_prefixes() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + ] + observation_prefix = "Custom Observation: " + llm_prefix = "Custom Thought: " + expected_result = "Log1\nCustom Observation: Observation1\nCustom Thought: " + assert ( + format_log_to_str(intermediate_steps, observation_prefix, llm_prefix) + == expected_result + ) + + +def test_empty_intermediate_steps() -> None: + output = format_log_to_str([]) + assert output == "" diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py new file mode 100644 index 0000000000..ed7664c8b0 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py @@ -0,0 +1,49 @@ +from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages +from langchain.schema.agent import AgentAction +from langchain.schema.messages import AIMessage, HumanMessage + + +def test_single_intermediate_step_default_response() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + ] + expected_result = [AIMessage(content="Log1"), HumanMessage(content="Observation1")] + assert format_log_to_messages(intermediate_steps) == expected_result + + +def test_multiple_intermediate_steps_default_response() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), + (AgentAction(tool="Tool2", tool_input="input2", log="Log2"), "Observation2"), + (AgentAction(tool="Tool3", tool_input="input3", log="Log3"), "Observation3"), + ] + expected_result = [ + AIMessage(content="Log1"), + HumanMessage(content="Observation1"), + AIMessage(content="Log2"), + HumanMessage(content="Observation2"), + AIMessage(content="Log3"), + HumanMessage(content="Observation3"), + ] + assert format_log_to_messages(intermediate_steps) == expected_result + + +def test_custom_template_tool_response() -> None: + intermediate_steps = [ + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + ] + template_tool_response = "Response: {observation}" + expected_result = [ + AIMessage(content="Log1"), + HumanMessage(content="Response: Observation1"), + ] + assert ( + format_log_to_messages( + intermediate_steps, template_tool_response=template_tool_response + ) + == expected_result + ) + + +def test_empty_steps() -> None: + assert format_log_to_messages([]) == [] diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py new file mode 100644 index 0000000000..92dc5bcb6d --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py @@ -0,0 +1,60 @@ +from langchain.agents.format_scratchpad.openai_functions import ( + format_to_openai_functions, +) +from langchain.schema.agent import AgentActionMessageLog +from langchain.schema.messages import AIMessage, FunctionMessage + + +def test_calls_convert_agent_action_to_messages() -> None: + additional_kwargs1 = { + "function_call": { + "name": "tool1", + "arguments": "input1", + } + } + message1 = AIMessage(content="", additional_kwargs=additional_kwargs1) + action1 = AgentActionMessageLog( + tool="tool1", tool_input="input1", log="log1", message_log=[message1] + ) + additional_kwargs2 = { + "function_call": { + "name": "tool2", + "arguments": "input2", + } + } + message2 = AIMessage(content="", additional_kwargs=additional_kwargs2) + action2 = AgentActionMessageLog( + tool="tool2", tool_input="input2", log="log2", message_log=[message2] + ) + + additional_kwargs3 = { + "function_call": { + "name": "tool3", + "arguments": "input3", + } + } + message3 = AIMessage(content="", additional_kwargs=additional_kwargs3) + action3 = AgentActionMessageLog( + tool="tool3", tool_input="input3", log="log3", message_log=[message3] + ) + + intermediate_steps = [ + (action1, "observation1"), + (action2, "observation2"), + (action3, "observation3"), + ] + expected_messages = [ + message1, + FunctionMessage(name="tool1", content="observation1"), + message2, + FunctionMessage(name="tool2", content="observation2"), + message3, + FunctionMessage(name="tool3", content="observation3"), + ] + output = format_to_openai_functions(intermediate_steps) + assert output == expected_messages + + +def test_handles_empty_input_list() -> None: + output = format_to_openai_functions([]) + assert output == [] diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py new file mode 100644 index 0000000000..2509091ffd --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py @@ -0,0 +1,40 @@ +from langchain.agents.format_scratchpad.xml import format_xml +from langchain.schema.agent import AgentAction + + +def test_single_agent_action_observation() -> None: + # Arrange + agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1") + observation = "Observation1" + intermediate_steps = [(agent_action, observation)] + + # Act + result = format_xml(intermediate_steps) + expected_result = """Tool1Input1\ +Observation1""" + # Assert + assert result == expected_result + + +def test_multiple_agent_actions_observations() -> None: + # Arrange + agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1") + agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2") + observation1 = "Observation1" + observation2 = "Observation2" + intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)] + + # Act + result = format_xml(intermediate_steps) + + # Assert + expected_result = """Tool1Input1\ +Observation1\ +Tool2Input2\ +Observation2""" + assert result == expected_result + + +def test_empty_list_agent_actions() -> None: + result = format_xml([]) + assert result == ""