Add type inference for output parsers (#2769)

Currently, the output type of a number of OutputParser's `parse` methods
is `Any` when it can in fact be inferred.

This PR makes BaseOutputParser use a generic type and fixes the output
types of the following parsers:
- `PydanticOutputParser`
- `OutputFixingParser`
- `RetryOutputParser`
- `RetryWithErrorOutputParser`

The output of the `StructuredOutputParser` is corrected from `BaseModel`
to `Any` since there are no type guarantees provided by the parser.

Fixes issue #2715
This commit is contained in:
Joshua Snyder 2023-04-12 18:12:20 +02:00 committed by GitHub
parent 789cc314c5
commit 59d054308c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 32 deletions

View File

@ -1,30 +1,32 @@
from __future__ import annotations
from typing import Any
from typing import TypeVar
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
T = TypeVar("T")
class OutputFixingParser(BaseOutputParser):
class OutputFixingParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
parser: BaseOutputParser[T]
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> OutputFixingParser:
) -> OutputFixingParser[T]:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse(self, completion: str) -> Any:
def parse(self, completion: str) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:

View File

@ -1,17 +1,19 @@
import json
import re
from typing import Any
from typing import Type, TypeVar
from pydantic import BaseModel, ValidationError
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
T = TypeVar("T", bound=BaseModel)
class PydanticOutputParser(BaseOutputParser):
pydantic_object: Any
def parse(self, text: str) -> BaseModel:
class PydanticOutputParser(BaseOutputParser[T]):
pydantic_object: Type[T]
def parse(self, text: str) -> T:
try:
# Greedy search for 1st json candidate.
match = re.search(
@ -38,6 +40,6 @@ class PydanticOutputParser(BaseOutputParser):
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema)
schema_str = json.dumps(reduced_schema)
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema)
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from typing import TypeVar
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)
T = TypeVar("T")
class RetryOutputParser(BaseOutputParser):
class RetryOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt.
"""
parser: BaseOutputParser
parser: BaseOutputParser[T]
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser:
) -> RetryOutputParser[T]:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser):
return parsed_completion
def parse(self, completion: str) -> Any:
def parse(self, completion: str) -> T:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)
@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser):
return self.parser.get_format_instructions()
class RetryWithErrorOutputParser(BaseOutputParser):
class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error
@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser):
LLM, which in theory should give it more information on how to fix it.
"""
parser: BaseOutputParser
parser: BaseOutputParser[T]
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser:
) -> RetryWithErrorOutputParser[T]:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser):
return parsed_completion
def parse(self, completion: str) -> Any:
def parse(self, completion: str) -> T:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import List
from typing import Any, List
from pydantic import BaseModel
@ -37,7 +37,7 @@ class StructuredOutputParser(BaseOutputParser):
)
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> BaseModel:
def parse(self, text: str) -> Any:
json_string = text.split("```json")[1].strip().strip("```").strip()
try:
json_obj = json.loads(json_string)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar
from pydantic import BaseModel, Extra, Field, root_validator
@ -327,15 +327,17 @@ class BaseRetriever(ABC):
Memory = BaseMemory
T = TypeVar("T")
class BaseOutputParser(BaseModel, ABC):
class BaseOutputParser(BaseModel, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
@abstractmethod
def parse(self, text: str) -> Any:
def parse(self, text: str) -> T:
"""Parse the output of an LLM call.
A method which takes in a string (assumed output of language model )

View File

@ -46,7 +46,9 @@ DEF_EXPECTED_RESULT = TestModel(
def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser."""
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
pydantic_object=TestModel
)
result = pydantic_parser.parse(DEF_RESULT)
print("parse_result:", result)
@ -56,7 +58,9 @@ def test_pydantic_output_parser() -> None:
def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
pydantic_object=TestModel
)
try:
pydantic_parser.parse(DEF_RESULT_FAIL)