mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[patch]: improve PydanticOutputParser typing (#18740)
This PR adds generic typing to `PydanticOutputParser` so we get a typed output from `.parse` instead of `Any`. It should provide a better DX by way of Intellisense and for anyone strictly typing. Pre-change: ![Screenshot 2024-03-07 at 10 22 31 AM](https://github.com/langchain-ai/langchain/assets/22690160/fd22dde0-9fdc-4283-b283-4c98f0bc46e5) Post-change: ![Screenshot 2024-03-07 at 10 26 31 AM](https://github.com/langchain-ai/langchain/assets/22690160/7e23d2b7-8f8c-494f-80b3-187530a173ee) I haven't dug too deep, but I think a similar change could probably be added to `JsonOutputParser` so we don't have to pull up `.parse`. Co-authored-by: Jan Nissen <jan23@gmail.com>
This commit is contained in:
parent
3b975c6ebe
commit
b8922480ed
@ -1,23 +1,27 @@
|
||||
import json
|
||||
from typing import Any, List, Type
|
||||
from typing import Generic, List, Type, TypeVar
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser):
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[BaseModel]
|
||||
pydantic_object: Type[TBaseModel]
|
||||
"""The pydantic model to parse.
|
||||
|
||||
Attention: To avoid potential compatibility issues, it's recommended to use
|
||||
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> TBaseModel:
|
||||
json_object = super().parse_result(result)
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(json_object)
|
||||
@ -26,6 +30,9 @@ class PydanticOutputParser(JsonOutputParser):
|
||||
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
|
||||
raise OutputParserException(msg, llm_output=json_object)
|
||||
|
||||
def parse(self, text: str) -> TBaseModel:
|
||||
return super().parse(text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
# Copy schema to avoid altering original Pydantic schema.
|
||||
schema = {k: v for k, v in self.pydantic_object.schema().items()}
|
||||
@ -46,7 +53,7 @@ class PydanticOutputParser(JsonOutputParser):
|
||||
return "pydantic"
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[BaseModel]:
|
||||
def OutputType(self) -> Type[TBaseModel]:
|
||||
"""Return the pydantic model."""
|
||||
return self.pydantic_object
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user