diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index abfcb73fcd..9dd0a33d71 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,23 +1,27 @@ import json -from typing import Any, List, Type +from typing import Generic, List, Type, TypeVar from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser from langchain_core.outputs import Generation from langchain_core.pydantic_v1 import BaseModel, ValidationError +TBaseModel = TypeVar("TBaseModel", bound=BaseModel) -class PydanticOutputParser(JsonOutputParser): + +class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): """Parse an output using a pydantic model.""" - pydantic_object: Type[BaseModel] + pydantic_object: Type[TBaseModel] """The pydantic model to parse. Attention: To avoid potential compatibility issues, it's recommended to use pydantic <2 or leverage the v1 namespace in pydantic >= 2. """ - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> TBaseModel: json_object = super().parse_result(result) try: return self.pydantic_object.parse_obj(json_object) @@ -26,6 +30,9 @@ class PydanticOutputParser(JsonOutputParser): msg = f"Failed to parse {name} from completion {json_object}. Got: {e}" raise OutputParserException(msg, llm_output=json_object) + def parse(self, text: str) -> TBaseModel: + return super().parse(text) + def get_format_instructions(self) -> str: # Copy schema to avoid altering original Pydantic schema. schema = {k: v for k, v in self.pydantic_object.schema().items()} @@ -46,7 +53,7 @@ class PydanticOutputParser(JsonOutputParser): return "pydantic" @property - def OutputType(self) -> Type[BaseModel]: + def OutputType(self) -> Type[TBaseModel]: """Return the pydantic model.""" return self.pydantic_object