Supported RetryOutputParser & RetryWithErrorOutputParser max_retries (#11903)

Description: Supported RetryOutputParser & RetryWithErrorOutputParser
max_retries
- max_retries: Maximum number of retries to parser.

Issue: None
Dependencies: None
Tag maintainer: @baskaryan 
Twitter handle:
pull/12016/head
John Mai 12 months ago committed by GitHub
parent 008c7df80d
commit a6b483dcbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,9 +17,12 @@ class OutputFixingParser(BaseOutputParser[T]):
return True
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@ -35,7 +38,7 @@ class OutputFixingParser(BaseOutputParser[T]):
llm: llm to use for fixing
parser: parser to use for parsing
prompt: prompt to use for fixing
max_retries: Maximum number of retries to parser.
max_retries: Maximum number of retries to parse.
Returns:
OutputFixingParser

@ -48,6 +48,8 @@ class RetryOutputParser(BaseOutputParser[T]):
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@ -55,11 +57,23 @@ class RetryOutputParser(BaseOutputParser[T]):
llm: BaseLanguageModel,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
max_retries: int = 1,
) -> RetryOutputParser[T]:
"""Create an OutputFixingParser from a language model and a parser.
Args:
llm: llm to use for fixing
parser: parser to use for parsing
prompt: prompt to use for fixing
max_retries: Maximum number of retries to parse.
Returns:
RetryOutputParser
"""
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
"""Parse the output of an LLM call using a wrapped parser.
@ -71,15 +85,21 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
The parsed completion.
"""
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
retries = 0
while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
raise OutputParserException("Failed to parse")
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
"""Parse the output of an LLM call using a wrapped parser.
@ -91,15 +111,21 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
The parsed completion.
"""
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
retries = 0
while retries <= self.max_retries:
try:
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion
)
raise OutputParserException("Failed to parse")
def parse(self, completion: str) -> T:
raise NotImplementedError(
@ -125,8 +151,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@ -134,6 +164,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
llm: BaseLanguageModel,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
max_retries: int = 1,
) -> RetryWithErrorOutputParser[T]:
"""Create a RetryWithErrorOutputParser from an LLM.
@ -141,6 +172,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
llm: The LLM to use to retry the completion.
parser: The parser to use to parse the output.
prompt: The prompt to use to retry the completion.
max_retries: The maximum number of times to retry the completion.
Returns:
A RetryWithErrorOutputParser.
@ -148,29 +180,45 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
retries = 0
while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
retries = 0
while retries <= self.max_retries:
try:
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
def parse(self, completion: str) -> T:
raise NotImplementedError(

Loading…
Cancel
Save