@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from typing import TypeVar
from langchain . chains . llm import LLMChain
from langchain . prompts . base import BasePromptTemplate
@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)
T = TypeVar ( " T " )
class RetryOutputParser ( BaseOutputParser ) :
class RetryOutputParser ( BaseOutputParser [ T ] ) :
""" Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt and the completion to another
LLM , and telling it the completion did not satisfy criteria in the prompt .
"""
parser : BaseOutputParser
parser : BaseOutputParser [ T ]
retry_chain : LLMChain
@classmethod
def from_llm (
cls ,
llm : BaseLanguageModel ,
parser : BaseOutputParser ,
parser : BaseOutputParser [T ] ,
prompt : BasePromptTemplate = NAIVE_RETRY_PROMPT ,
) - > RetryOutputParser :
) - > RetryOutputParser [T ] :
chain = LLMChain ( llm = llm , prompt = prompt )
return cls ( parser = parser , retry_chain = chain )
def parse_with_prompt ( self , completion : str , prompt_value : PromptValue ) - > Any :
def parse_with_prompt ( self , completion : str , prompt_value : PromptValue ) - > T :
try :
parsed_completion = self . parser . parse ( completion )
except OutputParserException :
@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser):
return parsed_completion
def parse ( self , completion : str ) - > Any :
def parse ( self , completion : str ) - > T :
raise NotImplementedError (
" This OutputParser can only be called by the `parse_with_prompt` method. "
)
@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser):
return self . parser . get_format_instructions ( )
class RetryWithErrorOutputParser ( BaseOutputParser ):
class RetryWithErrorOutputParser ( BaseOutputParser [T ] ):
""" Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt , the completion , AND the error
@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser):
LLM , which in theory should give it more information on how to fix it .
"""
parser : BaseOutputParser
parser : BaseOutputParser [ T ]
retry_chain : LLMChain
@classmethod
def from_llm (
cls ,
llm : BaseLanguageModel ,
parser : BaseOutputParser ,
parser : BaseOutputParser [T ] ,
prompt : BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT ,
) - > RetryWithErrorOutputParser :
) - > RetryWithErrorOutputParser [T ] :
chain = LLMChain ( llm = llm , prompt = prompt )
return cls ( parser = parser , retry_chain = chain )
def parse_with_prompt ( self , completion : str , prompt_value : PromptValue ) - > Any :
def parse_with_prompt ( self , completion : str , prompt_value : PromptValue ) - > T :
try :
parsed_completion = self . parser . parse ( completion )
except OutputParserException as e :
@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser):
return parsed_completion
def parse ( self , completion : str ) - > Any :
def parse ( self , completion : str ) - > T :
raise NotImplementedError (
" This OutputParser can only be called by the `parse_with_prompt` method. "
)