core[patch]: improve PydanticOutputParser typing (#18740)

This PR adds generic typing to `PydanticOutputParser` so that `.parse`
returns the concrete pydantic model type instead of `Any`. This should
improve the developer experience, both through IntelliSense/autocomplete
and for anyone using strict type checking.

Pre-change:

![Screenshot 2024-03-07 at 10 22
31 AM](https://github.com/langchain-ai/langchain/assets/22690160/fd22dde0-9fdc-4283-b283-4c98f0bc46e5)

Post-change:

![Screenshot 2024-03-07 at 10 26
31 AM](https://github.com/langchain-ai/langchain/assets/22690160/7e23d2b7-8f8c-494f-80b3-187530a173ee)

I haven't dug too deep, but I think a similar change could probably be
made to `JsonOutputParser` itself, so that subclasses don't have to
override `.parse` just to narrow its return type.

Co-authored-by: Jan Nissen <jan23@gmail.com>
This commit is contained in:
Jan Nissen 2024-03-07 22:25:24 -05:00 committed by GitHub
parent 3b975c6ebe
commit b8922480ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,23 +1,27 @@
import json import json
from typing import Any, List, Type from typing import Generic, List, Type, TypeVar
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import BaseModel, ValidationError
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
class PydanticOutputParser(JsonOutputParser):
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model.""" """Parse an output using a pydantic model."""
pydantic_object: Type[BaseModel] pydantic_object: Type[TBaseModel]
"""The pydantic model to parse. """The pydantic model to parse.
Attention: To avoid potential compatibility issues, it's recommended to use Attention: To avoid potential compatibility issues, it's recommended to use
pydantic <2 or leverage the v1 namespace in pydantic >= 2. pydantic <2 or leverage the v1 namespace in pydantic >= 2.
""" """
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> TBaseModel:
json_object = super().parse_result(result) json_object = super().parse_result(result)
try: try:
return self.pydantic_object.parse_obj(json_object) return self.pydantic_object.parse_obj(json_object)
@ -26,6 +30,9 @@ class PydanticOutputParser(JsonOutputParser):
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}" msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserException(msg, llm_output=json_object) raise OutputParserException(msg, llm_output=json_object)
def parse(self, text: str) -> TBaseModel:
return super().parse(text)
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
# Copy schema to avoid altering original Pydantic schema. # Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()} schema = {k: v for k, v in self.pydantic_object.schema().items()}
@ -46,7 +53,7 @@ class PydanticOutputParser(JsonOutputParser):
return "pydantic" return "pydantic"
@property @property
def OutputType(self) -> Type[BaseModel]: def OutputType(self) -> Type[TBaseModel]:
"""Return the pydantic model.""" """Return the pydantic model."""
return self.pydantic_object return self.pydantic_object