core[minor]: support pydantic v2 models in PydanticOutputParser (#18811)

As mentioned in #18322, the current PydanticOutputParser won't work for
anyone trying to parse to pydantic v2 models. This PR adds a separate
`PydanticV2OutputParser`, as well as a `langchain_core.pydantic_v2`
namespace that will fail on import to any projects using pydantic<2.
Happy to update the docs for output parsers if this is something we're
interesting in adding.

On a separate note, I also updated `check_pydantic.sh` to detect
pydantic imports with leading whitespace and excluded the internal
namespaces. That change can be separated into its own PR if needed.

---------

Co-authored-by: Jan Nissen <jan23@gmail.com>
pull/18991/head
Jan Nissen 2 months ago committed by GitHub
parent d0accc3275
commit 2e0ddd6fb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,34 +1,64 @@
import json
from typing import Generic, List, Type, TypeVar
from typing import Generic, List, Type, TypeVar, Union
import pydantic # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
if PYDANTIC_MAJOR_VERSION < 2:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel # pydantic: ignore
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
pydantic_object: Type[TBaseModel]
"""The pydantic model to parse.
Attention: To avoid potential compatibility issues, it's recommended to use
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
"""
pydantic_object: Type[TBaseModel] # type: ignore
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
if PYDANTIC_MAJOR_VERSION == 2:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
else:
raise OutputParserException(
f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj)
else: # pydantic v1
try:
return self.pydantic_object.parse_obj(obj)
except pydantic.ValidationError as e:
raise self._parser_exception(e, obj)
def _parser_exception(
self, e: Exception, json_object: dict
) -> OutputParserException:
json_string = json.dumps(json_object)
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
return OutputParserException(msg, llm_output=json_string)
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> TBaseModel:
json_object = super().parse_result(result)
try:
return self.pydantic_object.parse_obj(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserException(msg, llm_output=json_object)
return self._parse_obj(json_object)
def parse(self, text: str) -> TBaseModel:
return super().parse(text)

@ -14,7 +14,10 @@ fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
result=$(
git -C "$repository_path" grep -E '^[[:space:]]*import pydantic|^[[:space:]]*from pydantic' \
-- ':!langchain_core/pydantic_*' ':!langchain_core/utils' | grep -v 'pydantic: ignore'
)
# Check if any matching lines were found
if [ -n "$result" ]; then
@ -23,5 +26,6 @@ if [ -n "$result" ]; then
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
echo "If this was intentional, you can add # pydantic: ignore after the import to ignore this error."
exit 1
fi

@ -0,0 +1,72 @@
from typing import Literal
import pydantic # pydantic: ignore
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel # pydantic: ignore
V1BaseModel = BaseModel # type: ignore
class ForecastV2(pydantic.BaseModel):
temperature: int
f_or_c: Literal["F", "C"]
forecast: str
class ForecastV1(V1BaseModel):
temperature: int
f_or_c: Literal["F", "C"]
forecast: str
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
def test_pydantic_parser_chaining(
pydantic_object: TBaseModel,
) -> None:
prompt = PromptTemplate(
template="""{{
"temperature": 20,
"f_or_c": "C",
"forecast": "Sunny"
}}""",
input_variables=[],
)
model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore
chain = prompt | model | parser
res = chain.invoke({})
assert type(res) == pydantic_object
assert res.f_or_c == "C"
assert res.temperature == 20
assert res.forecast == "Sunny"
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
bad_prompt = PromptTemplate(
template="""{{
"temperature": "oof",
"f_or_c": 1,
"forecast": "Sunny"
}}""",
input_variables=[],
)
model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore
chain = bad_prompt | model | parser
with pytest.raises(OutputParserException):
chain.invoke({})
Loading…
Cancel
Save