From 33cdb06b5c9d4d3e7f54d5e1e7c980dfae33923b Mon Sep 17 00:00:00 2001 From: Bennji94 Date: Mon, 7 Aug 2023 23:42:48 +0200 Subject: [PATCH] Async RetryOutputParser, RetryWithErrorOutputParser and OutputFixingParser (#8776) Added async parsing functions for RetryOutputParser, RetryWithErrorOutputParser and OutputFixingParser. The async parse functions call the arun methods of the used LLMChains. Fix for #7989 --------- Co-authored-by: Benjamin May --- .../langchain/langchain/output_parsers/fix.py | 13 ++++++++ .../langchain/output_parsers/retry.py | 31 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index cc8bc30bf6..af5e33b42f 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -53,6 +53,19 @@ class OutputFixingParser(BaseOutputParser[T]): return parsed_completion + async def aparse(self, completion: str) -> T: + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException as e: + new_completion = await self.retry_chain.arun( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + def get_format_instructions(self) -> str: return self.parser.get_format_instructions() diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index 0f2c4f7ac1..c9b7337701 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -79,6 +79,26 @@ class RetryOutputParser(BaseOutputParser[T]): return parsed_completion + async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: + """Parse the output of an LLM call using a wrapped parser. + + Args: + completion: The chain completion to parse. + prompt_value: The prompt to use to parse the completion. + + Returns: + The parsed completion. + """ + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException: + new_completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), completion=completion + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method." @@ -136,6 +156,17 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): return parsed_completion + async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException as e: + new_completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), completion=completion, error=repr(e) + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method."