Don't type chains in output_parsers (#11092)

Can't use TYPE_CHECKING style imports for pydantic params because it will try to instantiate the typed object by default.
pull/11109/head
Bagatur 12 months ago committed by GitHub
parent 64385c4eae
commit 5514ebe859
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,14 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar from typing import Any, TypeVar
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
if TYPE_CHECKING:
from langchain.chains.llm import LLMChain
T = TypeVar("T") T = TypeVar("T")
@ -20,7 +17,8 @@ class OutputFixingParser(BaseOutputParser[T]):
return True return True
parser: BaseOutputParser[T] parser: BaseOutputParser[T]
retry_chain: LLMChain # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
@classmethod @classmethod
def from_llm( def from_llm(

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar from typing import Any, TypeVar
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ( from langchain.schema import (
@ -11,9 +11,6 @@ from langchain.schema import (
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
if TYPE_CHECKING:
from langchain.chains.llm import LLMChain
NAIVE_COMPLETION_RETRY = """Prompt: NAIVE_COMPLETION_RETRY = """Prompt:
{prompt} {prompt}
Completion: Completion:
@ -48,7 +45,8 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T] parser: BaseOutputParser[T]
"""The parser to use to parse the output.""" """The parser to use to parse the output."""
retry_chain: LLMChain # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion.""" """The LLMChain to use to retry the completion."""
@classmethod @classmethod
@ -127,7 +125,8 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
""" """
parser: BaseOutputParser[T] parser: BaseOutputParser[T]
retry_chain: LLMChain # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
@classmethod @classmethod
def from_llm( def from_llm(

Loading…
Cancel
Save