format intermediate steps (#10794)

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
pull/10856/head
Harrison Chase 12 months ago committed by GitHub
parent 386ef1e654
commit 7dec2d399b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,16 @@
from typing import List, Tuple
from langchain.schema.agent import AgentAction
def format_log_to_str(
    intermediate_steps: List[Tuple[AgentAction, str]],
    observation_prefix: str = "Observation: ",
    llm_prefix: str = "Thought: ",
) -> str:
    """Construct the scratchpad that lets the agent continue its thought process.

    Args:
        intermediate_steps: (action, observation) pairs the agent has produced so far.
        observation_prefix: Text placed immediately before each observation.
        llm_prefix: Text appended after each observation to cue the next thought.

    Returns:
        A single string: each action's log, a newline, the prefixed observation,
        a newline, and the llm prefix — concatenated in step order.
    """
    parts = []
    for step, result in intermediate_steps:
        parts.append(f"{step.log}\n{observation_prefix}{result}\n{llm_prefix}")
    return "".join(parts)

@ -0,0 +1,19 @@
from typing import List, Tuple
from langchain.schema.agent import AgentAction
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
def format_log_to_messages(
    intermediate_steps: List[Tuple[AgentAction, str]],
    template_tool_response: str = "{observation}",
) -> List[BaseMessage]:
    """Construct the scratchpad that lets the agent continue its thought process.

    Args:
        intermediate_steps: (action, observation) pairs produced so far.
        template_tool_response: Format string applied to each observation; must
            contain an ``{observation}`` placeholder.

    Returns:
        Alternating messages, in step order: the action's log as an
        ``AIMessage``, then the templated observation as a ``HumanMessage``.
    """
    messages: List[BaseMessage] = []
    for step, result in intermediate_steps:
        messages.append(AIMessage(content=step.log))
        messages.append(
            HumanMessage(content=template_tool_response.format(observation=result))
        )
    return messages

@ -0,0 +1,66 @@
import json
from typing import List, Sequence, Tuple
from langchain.schema.agent import AgentAction, AgentActionMessageLog
from langchain.schema.messages import AIMessage, BaseMessage, FunctionMessage
def _convert_agent_action_to_messages(
    agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
    """Convert an agent action to a message.

    This code is used to reconstruct the original AI message from the agent
    action.

    Args:
        agent_action: Agent action to convert.
        observation: Result of the corresponding tool invocation.

    Returns:
        Messages that correspond to the original tool invocation: the action's
        recorded message log plus a trailing function message when available,
        otherwise a single ``AIMessage`` built from the raw log text.
    """
    # Plain actions carry no message log — fall back to the raw log string.
    if not isinstance(agent_action, AgentActionMessageLog):
        return [AIMessage(content=agent_action.log)]
    function_message = _create_function_message(agent_action, observation)
    return list(agent_action.message_log) + [function_message]
def _create_function_message(
    agent_action: AgentAction, observation: str
) -> FunctionMessage:
    """Convert agent action and observation into a function message.

    Args:
        agent_action: the tool invocation request from the agent
        observation: the result of the tool invocation

    Returns:
        FunctionMessage that corresponds to the original tool invocation
    """
    if isinstance(observation, str):
        content = observation
    else:
        # Best-effort: serialize non-string observations to JSON; anything
        # that json can't handle falls back to its str() representation.
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            content = str(observation)
    return FunctionMessage(name=agent_action.tool, content=content)
def format_to_openai_functions(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Format intermediate steps.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with observations

    Returns:
        list of messages to send to the LLM for the next prediction
    """
    # Flatten: each step may expand to several messages (message log + function
    # message) or a single AIMessage, preserving step order.
    return [
        message
        for agent_action, observation in intermediate_steps
        for message in _convert_agent_action_to_messages(agent_action, observation)
    ]

@ -0,0 +1,15 @@
from typing import List, Tuple
from langchain.schema.agent import AgentAction
def format_xml(
    intermediate_steps: List[Tuple[AgentAction, str]],
) -> str:
    """Render (action, observation) pairs as a flat XML-tagged string.

    NOTE(review): values are interpolated verbatim — no XML escaping is
    performed, so tool inputs or observations containing angle brackets will
    break the tag structure. Presumably downstream parsers accept the raw
    form; confirm before changing.
    """
    fragments = [
        f"<tool>{step.tool}</tool><tool_input>{step.tool_input}"
        f"</tool_input><observation>{result}</observation>"
        for step, result in intermediate_steps
    ]
    return "".join(fragments)

@ -1,7 +1,9 @@
"""Memory used to save agent output AND intermediate steps.""" """Memory used to save agent output AND intermediate steps."""
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
)
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.messages import BaseMessage, get_buffer_string
@ -50,7 +52,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
"""Save context from this conversation to buffer. Pruned.""" """Save context from this conversation to buffer. Pruned."""
input_str, output_str = self._get_input_output(inputs, outputs) input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str) self.chat_memory.add_user_message(input_str)
steps = _format_intermediate_steps(outputs[self.intermediate_steps_key]) steps = format_to_openai_functions(outputs[self.intermediate_steps_key])
for msg in steps: for msg in steps:
self.chat_memory.add_message(msg) self.chat_memory.add_message(msg)
self.chat_memory.add_ai_message(output_str) self.chat_memory.add_ai_message(output_str)

@ -1,8 +1,10 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API.""" """Module implements an agent that uses OpenAI's APIs function enabled API."""
import json
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseSingleActionAgent from langchain.agents import BaseSingleActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
)
from langchain.agents.output_parsers.openai_functions import ( from langchain.agents.output_parsers.openai_functions import (
OpenAIFunctionsAgentOutputParser, OpenAIFunctionsAgentOutputParser,
) )
@ -21,82 +23,14 @@ from langchain.schema import (
AgentFinish, AgentFinish,
BasePromptTemplate, BasePromptTemplate,
) )
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage,
BaseMessage, BaseMessage,
FunctionMessage,
SystemMessage, SystemMessage,
) )
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function from langchain.tools.convert_to_openai import format_tool_to_openai_function
# For backwards compatibility
_FunctionsAgentAction = AgentActionMessageLog
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return list(agent_action.message_log) + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
class OpenAIFunctionsAgent(BaseSingleActionAgent): class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""An Agent driven by OpenAIs function powered API. """An Agent driven by OpenAIs function powered API.
@ -159,7 +93,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }
@ -198,7 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }

@ -4,6 +4,9 @@ from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseMultiActionAgent from langchain.agents import BaseMultiActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
)
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.openai import ChatOpenAI
@ -25,7 +28,6 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
FunctionMessage,
SystemMessage, SystemMessage,
) )
from langchain.tools import BaseTool from langchain.tools import BaseTool
@ -34,68 +36,6 @@ from langchain.tools import BaseTool
_FunctionsAgentAction = AgentActionMessageLog _FunctionsAgentAction = AgentActionMessageLog
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return list(agent_action.message_log) + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]: def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message.""" """Parse an AI message."""
if not isinstance(message, AIMessage): if not isinstance(message, AIMessage):
@ -257,7 +197,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }
@ -286,7 +226,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }

@ -0,0 +1,40 @@
from langchain.agents.format_scratchpad.log import format_log_to_str
from langchain.schema.agent import AgentAction
def test_single_agent_action_observation() -> None:
    """A single (action, observation) pair renders log + default prefixes."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    assert format_log_to_str(steps) == "Log1\nObservation: Observation1\nThought: "
def test_multiple_agent_actions_observations() -> None:
    """Multiple pairs concatenate in order with the default prefixes."""
    steps = [
        (
            AgentAction(tool=f"Tool{i}", tool_input=f"input{i}", log=f"Log{i}"),
            f"Observation{i}",
        )
        for i in (1, 2, 3)
    ]
    expected = (
        "Log1\nObservation: Observation1\nThought: "
        "Log2\nObservation: Observation2\nThought: "
        "Log3\nObservation: Observation3\nThought: "
    )
    assert format_log_to_str(steps) == expected
def test_custom_prefixes() -> None:
    """Caller-supplied observation/llm prefixes replace the defaults."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    result = format_log_to_str(steps, "Custom Observation: ", "Custom Thought: ")
    assert result == "Log1\nCustom Observation: Observation1\nCustom Thought: "
def test_empty_intermediate_steps() -> None:
    """No steps yields an empty scratchpad string."""
    assert format_log_to_str([]) == ""

@ -0,0 +1,49 @@
from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages
from langchain.schema.agent import AgentAction
from langchain.schema.messages import AIMessage, HumanMessage
def test_single_intermediate_step_default_response() -> None:
    """Default template passes the observation through unchanged."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    assert format_log_to_messages(steps) == [
        AIMessage(content="Log1"),
        HumanMessage(content="Observation1"),
    ]
def test_multiple_intermediate_steps_default_response() -> None:
    """Each pair contributes an AI message then a human message, in order."""
    steps = [
        (
            AgentAction(tool=f"Tool{i}", tool_input=f"input{i}", log=f"Log{i}"),
            f"Observation{i}",
        )
        for i in (1, 2, 3)
    ]
    expected = []
    for i in (1, 2, 3):
        expected.append(AIMessage(content=f"Log{i}"))
        expected.append(HumanMessage(content=f"Observation{i}"))
    assert format_log_to_messages(steps) == expected
def test_custom_template_tool_response() -> None:
    """A custom template wraps the observation inside the human message."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    result = format_log_to_messages(
        steps, template_tool_response="Response: {observation}"
    )
    assert result == [
        AIMessage(content="Log1"),
        HumanMessage(content="Response: Observation1"),
    ]
def test_empty_steps() -> None:
    """No intermediate steps yields an empty message list."""
    assert format_log_to_messages([]) == []

@ -0,0 +1,60 @@
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
)
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import AIMessage, FunctionMessage
def test_calls_convert_agent_action_to_messages() -> None:
    """Each message-log action expands to its original AI message followed by
    a FunctionMessage carrying the observation."""
    steps = []
    expected_messages = []
    for i in (1, 2, 3):
        tool_name = f"tool{i}"
        tool_input = f"input{i}"
        observation = f"observation{i}"
        ai_message = AIMessage(
            content="",
            additional_kwargs={
                "function_call": {"name": tool_name, "arguments": tool_input}
            },
        )
        action = AgentActionMessageLog(
            tool=tool_name,
            tool_input=tool_input,
            log=f"log{i}",
            message_log=[ai_message],
        )
        steps.append((action, observation))
        expected_messages.append(ai_message)
        expected_messages.append(
            FunctionMessage(name=tool_name, content=observation)
        )
    assert format_to_openai_functions(steps) == expected_messages
def test_handles_empty_input_list() -> None:
    """An empty step list produces an empty message list."""
    assert format_to_openai_functions([]) == []

@ -0,0 +1,40 @@
from langchain.agents.format_scratchpad.xml import format_xml
from langchain.schema.agent import AgentAction
def test_single_agent_action_observation() -> None:
    """One step renders tool, tool_input, and observation as tagged text."""
    action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
    result = format_xml([(action, "Observation1")])
    expected = (
        "<tool>Tool1</tool><tool_input>Input1"
        "</tool_input><observation>Observation1</observation>"
    )
    assert result == expected
def test_multiple_agent_actions_observations() -> None:
    """Two steps concatenate their tagged fragments in order."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="Input1", log="Log1"), "Observation1"),
        (AgentAction(tool="Tool2", tool_input="Input2", log="Log2"), "Observation2"),
    ]
    expected = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
        "<tool>Tool2</tool><tool_input>Input2</tool_input>"
        "<observation>Observation2</observation>"
    )
    assert format_xml(steps) == expected
def test_empty_list_agent_actions() -> None:
    """No steps yields an empty string."""
    assert format_xml([]) == ""
Loading…
Cancel
Save