From 2e0ddd6fb805af53d9bfb0edc463543a0312530f Mon Sep 17 00:00:00 2001 From: Jan Nissen Date: Wed, 27 Mar 2024 15:37:52 -0400 Subject: [PATCH] core[minor]: support pydantic v2 models in PydanticOutputParser (#18811) As mentioned in #18322, the current PydanticOutputParser won't work for anyone trying to parse to pydantic v2 models. This PR adds a separate `PydanticV2OutputParser`, as well as a `langchain_core.pydantic_v2` namespace that will fail on import to any projects using pydantic<2. Happy to update the docs for output parsers if this is something we're interested in adding. On a separate note, I also updated `check_pydantic.sh` to detect pydantic imports with leading whitespace and excluded the internal namespaces. That change can be separated into its own PR if needed. --------- Co-authored-by: Jan Nissen --- .../langchain_core/output_parsers/pydantic.py | 60 ++++++++++++---- libs/core/scripts/check_pydantic.sh | 6 +- .../output_parsers/test_pydantic_parser.py | 72 +++++++++++++++++++ 3 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 9dd0a33d71..73444d45af 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,34 +1,64 @@ import json -from typing import Generic, List, Type, TypeVar +from typing import Generic, List, Type, TypeVar, Union + +import pydantic # pydantic: ignore from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser from langchain_core.outputs import Generation -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION + +if PYDANTIC_MAJOR_VERSION < 2: + PydanticBaseModel = pydantic.BaseModel + +else: + from pydantic.v1 import 
BaseModel # pydantic: ignore + + # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. + PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore -TBaseModel = TypeVar("TBaseModel", bound=BaseModel) +TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): """Parse an output using a pydantic model.""" - pydantic_object: Type[TBaseModel] - """The pydantic model to parse. - - Attention: To avoid potential compatibility issues, it's recommended to use - pydantic <2 or leverage the v1 namespace in pydantic >= 2. - """ + pydantic_object: Type[TBaseModel] # type: ignore + """The pydantic model to parse.""" + + def _parse_obj(self, obj: dict) -> TBaseModel: + if PYDANTIC_MAJOR_VERSION == 2: + try: + if issubclass(self.pydantic_object, pydantic.BaseModel): + return self.pydantic_object.model_validate(obj) + elif issubclass(self.pydantic_object, pydantic.v1.BaseModel): + return self.pydantic_object.parse_obj(obj) + else: + raise OutputParserException( + f"Unsupported model version for PydanticOutputParser: \ + {self.pydantic_object.__class__}" + ) + except (pydantic.ValidationError, pydantic.v1.ValidationError) as e: + raise self._parser_exception(e, obj) + else: # pydantic v1 + try: + return self.pydantic_object.parse_obj(obj) + except pydantic.ValidationError as e: + raise self._parser_exception(e, obj) + + def _parser_exception( + self, e: Exception, json_object: dict + ) -> OutputParserException: + json_string = json.dumps(json_object) + name = self.pydantic_object.__name__ + msg = f"Failed to parse {name} from completion {json_string}. 
Got: {e}" + return OutputParserException(msg, llm_output=json_string) def parse_result( self, result: List[Generation], *, partial: bool = False ) -> TBaseModel: json_object = super().parse_result(result) - try: - return self.pydantic_object.parse_obj(json_object) - except ValidationError as e: - name = self.pydantic_object.__name__ - msg = f"Failed to parse {name} from completion {json_object}. Got: {e}" - raise OutputParserException(msg, llm_output=json_object) + return self._parse_obj(json_object) def parse(self, text: str) -> TBaseModel: return super().parse(text) diff --git a/libs/core/scripts/check_pydantic.sh b/libs/core/scripts/check_pydantic.sh index 06b5bb81ae..941fa6b1f4 100755 --- a/libs/core/scripts/check_pydantic.sh +++ b/libs/core/scripts/check_pydantic.sh @@ -14,7 +14,10 @@ fi repository_path="$1" # Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') +result=$( + git -C "$repository_path" grep -E '^[[:space:]]*import pydantic|^[[:space:]]*from pydantic' \ + -- ':!langchain_core/pydantic_*' ':!langchain_core/utils' | grep -v 'pydantic: ignore' +) # Check if any matching lines were found if [ -n "$result" ]; then @@ -23,5 +26,6 @@ if [ -n "$result" ]; then echo "Please replace the code with an import from langchain_core.pydantic_v1." echo "For example, replace 'from pydantic import BaseModel'" echo "with 'from langchain_core.pydantic_v1 import BaseModel'" + echo "If this was intentional, you can add # pydantic: ignore after the import to ignore this error." 
exit 1 fi diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py new file mode 100644 index 0000000000..bfb9f5c4cf --- /dev/null +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -0,0 +1,72 @@ +from typing import Literal + +import pydantic # pydantic: ignore +import pytest + +from langchain_core.exceptions import OutputParserException +from langchain_core.language_models import ParrotFakeChatModel +from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION + +V1BaseModel = pydantic.BaseModel +if PYDANTIC_MAJOR_VERSION == 2: + from pydantic.v1 import BaseModel # pydantic: ignore + + V1BaseModel = BaseModel # type: ignore + + +class ForecastV2(pydantic.BaseModel): + temperature: int + f_or_c: Literal["F", "C"] + forecast: str + + +class ForecastV1(V1BaseModel): + temperature: int + f_or_c: Literal["F", "C"] + forecast: str + + +@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) +def test_pydantic_parser_chaining( + pydantic_object: TBaseModel, +) -> None: + prompt = PromptTemplate( + template="""{{ + "temperature": 20, + "f_or_c": "C", + "forecast": "Sunny" + }}""", + input_variables=[], + ) + + model = ParrotFakeChatModel() + + parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore + chain = prompt | model | parser + + res = chain.invoke({}) + assert type(res) == pydantic_object + assert res.f_or_c == "C" + assert res.temperature == 20 + assert res.forecast == "Sunny" + + +@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) +def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None: + bad_prompt = PromptTemplate( + template="""{{ + "temperature": "oof", + "f_or_c": 1, + "forecast": "Sunny" + }}""", + input_variables=[], + ) + + 
model = ParrotFakeChatModel() + + parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore + chain = bad_prompt | model | parser + with pytest.raises(OutputParserException): + chain.invoke({})