Run tools concurrently in _atake_next_step (#2537)

small refactor to allow this
This commit is contained in:
Ankush Gola 2023-04-07 16:23:03 +02:00 committed by GitHub
parent 6dbd29e440
commit dca21078ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
"""Chain that takes in an input and produces an action and action input.""" """Chain that takes in an input and produces an action and action input."""
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import logging import logging
import time import time
@ -749,8 +750,10 @@ class AgentExecutor(Chain):
actions = [output] actions = [output]
else: else:
actions = output actions = output
result = []
for agent_action in actions: async def _aperform_agent_action(
agent_action: AgentAction,
) -> Tuple[AgentAction, str]:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_agent_action( await self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green" agent_action, verbose=self.verbose, color="green"
@ -782,9 +785,14 @@ class AgentExecutor(Chain):
color=None, color=None,
**tool_run_kwargs, **tool_run_kwargs,
) )
result.append((agent_action, observation)) return agent_action, observation
return result # Use asyncio.gather to run multiple tool.arun() calls concurrently
result = await asyncio.gather(
*[_aperform_agent_action(agent_action) for agent_action in actions]
)
return list(result)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response.""" """Run text through and get agent response."""