From 4e42c737f835ae202f78f3737d9a92cf8f2f0d8e Mon Sep 17 00:00:00 2001 From: blob42 Date: Thu, 18 May 2023 16:43:12 +0200 Subject: [PATCH] conv_chat: raise parsing error on output parser --- .../conversational_chat/output_parser.py | 41 ++++++++++--------- langchain/schema.py | 2 +- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/langchain/agents/conversational_chat/output_parser.py b/langchain/agents/conversational_chat/output_parser.py index 99880fac..d8a2c593 100644 --- a/langchain/agents/conversational_chat/output_parser.py +++ b/langchain/agents/conversational_chat/output_parser.py @@ -5,7 +5,7 @@ from typing import Union from langchain.agents import AgentOutputParser from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish +from langchain.schema import AgentAction, AgentFinish, OutputParserException class ConvoOutputParser(AgentOutputParser): @@ -13,24 +13,27 @@ class ConvoOutputParser(AgentOutputParser): return FORMAT_INSTRUCTIONS def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - cleaned_output = text.strip() - if "```json" in cleaned_output: - _, cleaned_output = cleaned_output.split("```json") - if "```" in cleaned_output: - cleaned_output, _ = cleaned_output.split("```") - if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] - if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] - if cleaned_output.endswith("```"): - cleaned_output = cleaned_output[: -len("```")] - cleaned_output = cleaned_output.strip() - response = json.loads(cleaned_output) - action, action_input = response["action"], response["action_input"] - if action == "Final Answer": - return AgentFinish({"output": action_input}, text) - else: - return AgentAction(action, action_input, text) + try: + cleaned_output = text.strip() + if "```json" in cleaned_output: + _, cleaned_output = cleaned_output.split("```json") + if "```" in cleaned_output: + cleaned_output, _ = cleaned_output.split("```") + if cleaned_output.startswith("```json"): + cleaned_output = cleaned_output[len("```json") :] + if cleaned_output.startswith("```"): + cleaned_output = cleaned_output[len("```") :] + if cleaned_output.endswith("```"): + cleaned_output = cleaned_output[: -len("```")] + cleaned_output = cleaned_output.strip() + response = json.loads(cleaned_output) + action, action_input = response["action"], response["action_input"] + if action == "Final Answer": + return AgentFinish({"output": action_input}, text) + else: + return AgentAction(action, action_input, text) + except Exception as e: + raise OutputParserException(f"Could not parse LLM output: {text}") from e @property def _type(self) -> str: diff --git a/langchain/schema.py b/langchain/schema.py index 21552b9b..156b781b 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -360,7 +360,7 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]): return output_parser_dict -class OutputParserException(Exception): +class OutputParserException(ValueError): """Exception that output parsers should raise to signify a parsing error. This exists to differentiate parsing errors from other code or execution errors