FIX: Infer runnable agent single or multi action (#13412)

pull/13420/head
Bagatur 11 months ago committed by GitHub
parent accadccf8e
commit 1372296dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -319,7 +319,7 @@ class BaseMultiActionAgent(BaseModel):
return {}
class AgentOutputParser(BaseOutputParser):
class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
"""Base class for parsing agent output into agent action/finish."""
@abstractmethod
@ -327,7 +327,9 @@ class AgentOutputParser(BaseOutputParser):
"""Parse text into agent action/finish."""
class MultiActionAgentOutputParser(BaseOutputParser):
class MultiActionAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
"""Base class for parsing agent output into agent actions/finish."""
@abstractmethod
@ -335,17 +337,87 @@ class MultiActionAgentOutputParser(BaseOutputParser):
"""Parse text into agent actions/finish."""
class RunnableAgent(BaseMultiActionAgent):
class RunnableAgent(BaseSingleActionAgent):
"""Agent powered by runnables."""
runnable: Union[
Runnable[dict, Union[AgentAction, AgentFinish]],
Runnable[dict, Union[List[AgentAction], AgentFinish]],
]
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
_input_keys: List[str] = []
"""Input keys."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return []
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
Returns:
List of input keys.
"""
return self._input_keys
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with the observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
return output
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[
AgentAction,
AgentFinish,
]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
return output
class RunnableMultiActionAgent(BaseMultiActionAgent):
"""Agent powered by runnables."""
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions."""
_input_keys: List[str] = []
"""Input keys."""
class Config:
"""Configuration for this pydantic object."""
@ -387,8 +459,6 @@ class RunnableAgent(BaseMultiActionAgent):
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
if isinstance(output, AgentAction):
output = [output]
return output
async def aplan(
@ -413,8 +483,6 @@ class RunnableAgent(BaseMultiActionAgent):
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
if isinstance(output, AgentAction):
output = [output]
return output
@ -840,7 +908,17 @@ class AgentExecutor(Chain):
"""Convert runnable to agent if passed in."""
agent = values["agent"]
if isinstance(agent, Runnable):
values["agent"] = RunnableAgent(runnable=agent)
try:
output_type = agent.OutputType
except Exception as _:
multi_action = False
else:
multi_action = output_type == Union[List[AgentAction], AgentFinish]
if multi_action:
values["agent"] = RunnableMultiActionAgent(runnable=agent)
else:
values["agent"] = RunnableAgent(runnable=agent)
return values
def save(self, file_path: Union[Path, str]) -> None:

Loading…
Cancel
Save