Supported OutputFixingParser max_retries (#11754)

Description: Supported OutputFixingParser max_retries
 - max_retries: Maximum number of retries to parser.

Issue: None
Dependencies: None
Tag maintainer: @baskaryan
Twitter handle: @JohnMai95

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
John Mai 2023-10-16 21:25:47 -05:00 committed by GitHub
parent c87b5c209d
commit 0169d45ba8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,6 +19,7 @@ class OutputFixingParser(BaseOutputParser[T]):
parser: BaseOutputParser[T] parser: BaseOutputParser[T]
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any retry_chain: Any
max_retries: int = 1
@classmethod @classmethod
def from_llm( def from_llm(
@ -26,6 +27,7 @@ class OutputFixingParser(BaseOutputParser[T]):
llm: BaseLanguageModel, llm: BaseLanguageModel,
parser: BaseOutputParser[T], parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
max_retries: int = 1,
) -> OutputFixingParser[T]: ) -> OutputFixingParser[T]:
"""Create an OutputFixingParser from a language model and a parser. """Create an OutputFixingParser from a language model and a parser.
@ -33,6 +35,7 @@ class OutputFixingParser(BaseOutputParser[T]):
llm: llm to use for fixing llm: llm to use for fixing
parser: parser to use for parsing parser: parser to use for parsing
prompt: prompt to use for fixing prompt: prompt to use for fixing
max_retries: Maximum number of retries to parser.
Returns: Returns:
OutputFixingParser OutputFixingParser
@ -40,33 +43,45 @@ class OutputFixingParser(BaseOutputParser[T]):
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse(self, completion: str) -> T: def parse(self, completion: str) -> T:
try: retries = 0
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
async def aparse(self, completion: str) -> T: async def aparse(self, completion: str) -> T:
try: retries = 0
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion while retries <= self.max_retries:
try:
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return self.parser.get_format_instructions() return self.parser.get_format_instructions()