retry to parsing (#3696)

This commit is contained in:
Harrison Chase 2023-05-01 22:05:42 -07:00 committed by GitHub
parent 3993166b5e
commit ca08a34a98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,9 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForChainRun, CallbackManagerForChainRun,
CallbackManagerForToolRun,
Callbacks, Callbacks,
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -578,6 +580,25 @@ class Agent(BaseSingleActionAgent):
} }
class ExceptionTool(BaseTool):
name = "_Exception"
description = "Exception tool"
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
return query
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return query
class AgentExecutor(Chain): class AgentExecutor(Chain):
"""Consists of an agent using tools.""" """Consists of an agent using tools."""
@ -587,6 +608,7 @@ class AgentExecutor(Chain):
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
early_stopping_method: str = "force" early_stopping_method: str = "force"
handle_parsing_errors: bool = False
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
@ -714,12 +736,28 @@ class AgentExecutor(Chain):
Override this to take control of how the agent makes and acts on choices. Override this to take control of how the agent makes and acts on choices.
""" """
# Call the LLM to see what to do. try:
output = self.agent.plan( # Call the LLM to see what to do.
intermediate_steps, output = self.agent.plan(
callbacks=run_manager.get_child() if run_manager else None, intermediate_steps,
**inputs, callbacks=run_manager.get_child() if run_manager else None,
) **inputs,
)
except Exception as e:
if not self.handle_parsing_errors:
raise e
text = str(e).split("`")[1]
observation = "Invalid or incomplete response"
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output
@ -772,12 +810,28 @@ class AgentExecutor(Chain):
Override this to take control of how the agent makes and acts on choices. Override this to take control of how the agent makes and acts on choices.
""" """
# Call the LLM to see what to do. try:
output = await self.agent.aplan( # Call the LLM to see what to do.
intermediate_steps, output = await self.agent.aplan(
callbacks=run_manager.get_child() if run_manager else None, intermediate_steps,
**inputs, callbacks=run_manager.get_child() if run_manager else None,
) **inputs,
)
except Exception as e:
if not self.handle_parsing_errors:
raise e
text = str(e).split("`")[1]
observation = "Invalid or incomplete response"
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun(
output.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output