|
|
|
@ -19,6 +19,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|
|
|
|
parser: BaseOutputParser[T]
|
|
|
|
|
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
|
|
|
|
retry_chain: Any
|
|
|
|
|
max_retries: int = 1
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_llm(
|
|
|
|
@ -26,6 +27,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
parser: BaseOutputParser[T],
|
|
|
|
|
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
|
|
|
|
max_retries: int = 1,
|
|
|
|
|
) -> OutputFixingParser[T]:
|
|
|
|
|
"""Create an OutputFixingParser from a language model and a parser.
|
|
|
|
|
|
|
|
|
@ -33,6 +35,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|
|
|
|
llm: llm to use for fixing
|
|
|
|
|
parser: parser to use for parsing
|
|
|
|
|
prompt: prompt to use for fixing
|
|
|
|
|
max_retries: Maximum number of retries to parser.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
OutputFixingParser
|
|
|
|
@ -40,33 +43,45 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
|
|
|
|
|
chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
|
|
return cls(parser=parser, retry_chain=chain)
|
|
|
|
|
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
|
|
|
|
|
|
|
|
|
def parse(self, completion: str) -> T:
|
|
|
|
|
try:
|
|
|
|
|
parsed_completion = self.parser.parse(completion)
|
|
|
|
|
except OutputParserException as e:
|
|
|
|
|
new_completion = self.retry_chain.run(
|
|
|
|
|
instructions=self.parser.get_format_instructions(),
|
|
|
|
|
completion=completion,
|
|
|
|
|
error=repr(e),
|
|
|
|
|
)
|
|
|
|
|
parsed_completion = self.parser.parse(new_completion)
|
|
|
|
|
|
|
|
|
|
return parsed_completion
|
|
|
|
|
retries = 0
|
|
|
|
|
|
|
|
|
|
while retries <= self.max_retries:
|
|
|
|
|
try:
|
|
|
|
|
return self.parser.parse(completion)
|
|
|
|
|
except OutputParserException as e:
|
|
|
|
|
if retries == self.max_retries:
|
|
|
|
|
raise e
|
|
|
|
|
else:
|
|
|
|
|
retries += 1
|
|
|
|
|
completion = self.retry_chain.run(
|
|
|
|
|
instructions=self.parser.get_format_instructions(),
|
|
|
|
|
completion=completion,
|
|
|
|
|
error=repr(e),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
raise OutputParserException("Failed to parse")
|
|
|
|
|
|
|
|
|
|
async def aparse(self, completion: str) -> T:
|
|
|
|
|
try:
|
|
|
|
|
parsed_completion = self.parser.parse(completion)
|
|
|
|
|
except OutputParserException as e:
|
|
|
|
|
new_completion = await self.retry_chain.arun(
|
|
|
|
|
instructions=self.parser.get_format_instructions(),
|
|
|
|
|
completion=completion,
|
|
|
|
|
error=repr(e),
|
|
|
|
|
)
|
|
|
|
|
parsed_completion = self.parser.parse(new_completion)
|
|
|
|
|
|
|
|
|
|
return parsed_completion
|
|
|
|
|
retries = 0
|
|
|
|
|
|
|
|
|
|
while retries <= self.max_retries:
|
|
|
|
|
try:
|
|
|
|
|
return await self.parser.aparse(completion)
|
|
|
|
|
except OutputParserException as e:
|
|
|
|
|
if retries == self.max_retries:
|
|
|
|
|
raise e
|
|
|
|
|
else:
|
|
|
|
|
retries += 1
|
|
|
|
|
completion = await self.retry_chain.arun(
|
|
|
|
|
instructions=self.parser.get_format_instructions(),
|
|
|
|
|
completion=completion,
|
|
|
|
|
error=repr(e),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
raise OutputParserException("Failed to parse")
|
|
|
|
|
|
|
|
|
|
def get_format_instructions(self) -> str:
|
|
|
|
|
return self.parser.get_format_instructions()
|
|
|
|
|