Add type inference for output parsers (#2769)

Currently, the output type of a number of OutputParser's `parse` methods
is `Any` when it can in fact be inferred.

This PR makes BaseOutputParser use a generic type and fixes the output
types of the following parsers:
- `PydanticOutputParser`
- `OutputFixingParser`
- `RetryOutputParser`
- `RetryWithErrorOutputParser`

The output of the `StructuredOutputParser` is corrected from `BaseModel`
to `Any` since there are no type guarantees provided by the parser.

Fixes issue #2715
This commit is contained in:
Joshua Snyder 2023-04-12 18:12:20 +02:00 committed by GitHub
parent 789cc314c5
commit 59d054308c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 32 deletions

View File

@ -1,30 +1,32 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import TypeVar
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
T = TypeVar("T")
class OutputFixingParser(BaseOutputParser):
class OutputFixingParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.""" """Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser parser: BaseOutputParser[T]
retry_chain: LLMChain retry_chain: LLMChain
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
parser: BaseOutputParser, parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> OutputFixingParser: ) -> OutputFixingParser[T]:
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)
def parse(self, completion: str) -> Any: def parse(self, completion: str) -> T:
try: try:
parsed_completion = self.parser.parse(completion) parsed_completion = self.parser.parse(completion)
except OutputParserException as e: except OutputParserException as e:

View File

@ -1,17 +1,19 @@
import json import json
import re import re
from typing import Any from typing import Type, TypeVar
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException from langchain.schema import BaseOutputParser, OutputParserException
T = TypeVar("T", bound=BaseModel)
class PydanticOutputParser(BaseOutputParser):
pydantic_object: Any
def parse(self, text: str) -> BaseModel: class PydanticOutputParser(BaseOutputParser[T]):
pydantic_object: Type[T]
def parse(self, text: str) -> T:
try: try:
# Greedy search for 1st json candidate. # Greedy search for 1st json candidate.
match = re.search( match = re.search(
@ -38,6 +40,6 @@ class PydanticOutputParser(BaseOutputParser):
if "type" in reduced_schema: if "type" in reduced_schema:
del reduced_schema["type"] del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes. # Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema) schema_str = json.dumps(reduced_schema)
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema) return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import TypeVar
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR NAIVE_COMPLETION_RETRY_WITH_ERROR
) )
T = TypeVar("T")
class RetryOutputParser(BaseOutputParser):
class RetryOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors. """Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt and the completion to another Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt. LLM, and telling it the completion did not satisfy criteria in the prompt.
""" """
parser: BaseOutputParser parser: BaseOutputParser[T]
retry_chain: LLMChain retry_chain: LLMChain
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
parser: BaseOutputParser, parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser: ) -> RetryOutputParser[T]:
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try: try:
parsed_completion = self.parser.parse(completion) parsed_completion = self.parser.parse(completion)
except OutputParserException: except OutputParserException:
@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser):
return parsed_completion return parsed_completion
def parse(self, completion: str) -> Any: def parse(self, completion: str) -> T:
raise NotImplementedError( raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method." "This OutputParser can only be called by the `parse_with_prompt` method."
) )
@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser):
return self.parser.get_format_instructions() return self.parser.get_format_instructions()
class RetryWithErrorOutputParser(BaseOutputParser): class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors. """Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error Does this by passing the original prompt, the completion, AND the error
@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser):
LLM, which in theory should give it more information on how to fix it. LLM, which in theory should give it more information on how to fix it.
""" """
parser: BaseOutputParser parser: BaseOutputParser[T]
retry_chain: LLMChain retry_chain: LLMChain
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
parser: BaseOutputParser, parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser: ) -> RetryWithErrorOutputParser[T]:
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try: try:
parsed_completion = self.parser.parse(completion) parsed_completion = self.parser.parse(completion)
except OutputParserException as e: except OutputParserException as e:
@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser):
return parsed_completion return parsed_completion
def parse(self, completion: str) -> Any: def parse(self, completion: str) -> T:
raise NotImplementedError( raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method." "This OutputParser can only be called by the `parse_with_prompt` method."
) )

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import List from typing import Any, List
from pydantic import BaseModel from pydantic import BaseModel
@ -37,7 +37,7 @@ class StructuredOutputParser(BaseOutputParser):
) )
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> BaseModel: def parse(self, text: str) -> Any:
json_string = text.split("```json")[1].strip().strip("```").strip() json_string = text.split("```json")[1].strip().strip("```").strip()
try: try:
json_obj = json.loads(json_string) json_obj = json.loads(json_string)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@ -327,15 +327,17 @@ class BaseRetriever(ABC):
Memory = BaseMemory Memory = BaseMemory
T = TypeVar("T")
class BaseOutputParser(BaseModel, ABC):
class BaseOutputParser(BaseModel, ABC, Generic[T]):
"""Class to parse the output of an LLM call. """Class to parse the output of an LLM call.
Output parsers help structure language model responses. Output parsers help structure language model responses.
""" """
@abstractmethod @abstractmethod
def parse(self, text: str) -> Any: def parse(self, text: str) -> T:
"""Parse the output of an LLM call. """Parse the output of an LLM call.
A method which takes in a string (assumed output of language model ) A method which takes in a string (assumed output of language model )

View File

@ -46,7 +46,9 @@ DEF_EXPECTED_RESULT = TestModel(
def test_pydantic_output_parser() -> None: def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser.""" """Test PydanticOutputParser."""
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
pydantic_object=TestModel
)
result = pydantic_parser.parse(DEF_RESULT) result = pydantic_parser.parse(DEF_RESULT)
print("parse_result:", result) print("parse_result:", result)
@ -56,7 +58,9 @@ def test_pydantic_output_parser() -> None:
def test_pydantic_output_parser_fail() -> None: def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation.""" """Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
pydantic_object=TestModel
)
try: try:
pydantic_parser.parse(DEF_RESULT_FAIL) pydantic_parser.parse(DEF_RESULT_FAIL)