diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index b0f23c0125..0b66e750a5 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -19,6 +19,7 @@ class OutputFixingParser(BaseOutputParser[T]): parser: BaseOutputParser[T] # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Any + max_retries: int = 1 @classmethod def from_llm( @@ -26,6 +27,7 @@ class OutputFixingParser(BaseOutputParser[T]): llm: BaseLanguageModel, parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, + max_retries: int = 1, ) -> OutputFixingParser[T]: """Create an OutputFixingParser from a language model and a parser. @@ -33,6 +35,7 @@ class OutputFixingParser(BaseOutputParser[T]): llm: llm to use for fixing parser: parser to use for parsing prompt: prompt to use for fixing + max_retries: Maximum number of retries to parser. Returns: OutputFixingParser @@ -40,33 +43,45 @@ class OutputFixingParser(BaseOutputParser[T]): from langchain.chains.llm import LLMChain chain = LLMChain(llm=llm, prompt=prompt) - return cls(parser=parser, retry_chain=chain) + return cls(parser=parser, retry_chain=chain, max_retries=max_retries) def parse(self, completion: str) -> T: - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException as e: - new_completion = self.retry_chain.run( - instructions=self.parser.get_format_instructions(), - completion=completion, - error=repr(e), - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return self.parser.parse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = self.retry_chain.run( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) + + raise OutputParserException("Failed to parse") async def aparse(self, completion: str) -> T: - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException as e: - new_completion = await self.retry_chain.arun( - instructions=self.parser.get_format_instructions(), - completion=completion, - error=repr(e), - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return await self.parser.aparse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = await self.retry_chain.arun( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) + + raise OutputParserException("Failed to parse") def get_format_instructions(self) -> str: return self.parser.get_format_instructions()