forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
|
|
|
|
|
class OutputFixingParser(BaseOutputParser):
|
|
"""Wraps a parser and tries to fix parsing errors."""
|
|
|
|
parser: BaseOutputParser
|
|
retry_chain: LLMChain
|
|
|
|
@classmethod
|
|
def from_llm(
|
|
cls,
|
|
llm: BaseLanguageModel,
|
|
parser: BaseOutputParser,
|
|
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
|
) -> OutputFixingParser:
|
|
chain = LLMChain(llm=llm, prompt=prompt)
|
|
return cls(parser=parser, retry_chain=chain)
|
|
|
|
def parse(self, completion: str) -> Any:
|
|
try:
|
|
parsed_completion = self.parser.parse(completion)
|
|
except OutputParserException as e:
|
|
new_completion = self.retry_chain.run(
|
|
instructions=self.parser.get_format_instructions(),
|
|
completion=completion,
|
|
error=repr(e),
|
|
)
|
|
parsed_completion = self.parser.parse(new_completion)
|
|
|
|
return parsed_completion
|
|
|
|
def get_format_instructions(self) -> str:
|
|
return self.parser.get_format_instructions()
|