langchain[patch]: Extract _aperform_agent_action from _aiter_next_step from AgentExecutor (#15707)

- **Description:** extreact the _aperform_agent_action in the
AgentExecutor class to allow for easier overriding. Extracted logic from
_iter_next_step into a new method _perform_agent_action for consistency
and easier overriding.
- **Issue:** #15706

Closes #15706
pull/16487/head
Gianfranco Demarco 5 months ago committed by GitHub
parent 95ee69a301
commit c69f599594
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1179,37 +1179,48 @@ class AgentExecutor(Chain):
for agent_action in actions:
yield agent_action
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
yield AgentStep(action=agent_action, observation=observation)
yield self._perform_agent_action(
name_to_tool_map, color_mapping, agent_action, run_manager
)
def _perform_agent_action(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
agent_action: AgentAction,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> AgentStep:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return AgentStep(action=agent_action, observation=observation)
async def _atake_next_step(
self,
@ -1303,52 +1314,61 @@ class AgentExecutor(Chain):
for agent_action in actions:
yield agent_action
async def _aperform_agent_action(
agent_action: AgentAction,
) -> AgentStep:
if run_manager:
await run_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return AgentStep(action=agent_action, observation=observation)
# Use asyncio.gather to run multiple tool.arun() calls concurrently
result = await asyncio.gather(
*[_aperform_agent_action(agent_action) for agent_action in actions]
*[
self._aperform_agent_action(
name_to_tool_map, color_mapping, agent_action, run_manager
)
for agent_action in actions
],
)
# TODO This could yield each result as it becomes available
for chunk in result:
yield chunk
async def _aperform_agent_action(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
agent_action: AgentAction,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AgentStep:
if run_manager:
await run_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return AgentStep(action=agent_action, observation=observation)
def _call(
self,
inputs: Dict[str, str],

Loading…
Cancel
Save