From 31e3ecc72866f6de74788e572e6829beb1a3a957 Mon Sep 17 00:00:00 2001 From: Jan Nissen Date: Thu, 4 Apr 2024 10:57:47 -0400 Subject: [PATCH] core[minor]: support pydantic V2 for JSONOutputParser, allow for other sources of JSON schemas (#19716) This PR supports using Pydantic v2 objects to generate the schema for the JSONOutputParser (#19441). This also adds a `json_schema` parameter to allow users to pass any JSON schema to validate with, not just pydantic. --- .../langchain_core/output_parsers/json.py | 28 ++++++++++++++++--- .../output_parsers/test_pydantic_parser.py | 26 +++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 3b986786ff..5d8298986b 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -3,15 +3,27 @@ from __future__ import annotations import json import re from json import JSONDecodeError -from typing import Any, Callable, List, Optional, Type +from typing import Any, Callable, List, Optional, Type, TypeVar, Union import jsonpatch # type: ignore[import] +import pydantic # pydantic: ignore from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.outputs import Generation -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION + +if PYDANTIC_MAJOR_VERSION < 2: + PydanticBaseModel = pydantic.BaseModel + +else: + from pydantic.v1 import BaseModel # pydantic: ignore + + # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. + PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore + +TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) def _replace_new_line(match: re.Match[str]) -> str: @@ -200,11 +212,19 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): describing the difference between the previous and the current object. """ - pydantic_object: Optional[Type[BaseModel]] = None + pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch + def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]: + if PYDANTIC_MAJOR_VERSION == 2: + if issubclass(pydantic_object, pydantic.BaseModel): + return pydantic_object.model_json_schema() + elif issubclass(pydantic_object, pydantic.v1.BaseModel): + return pydantic_object.schema() + return pydantic_object.schema() + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: text = result[0].text text = text.strip() @@ -228,7 +248,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): return "Return a JSON object." else: # Copy schema to avoid altering original Pydantic schema. - schema = {k: v for k, v in self.pydantic_object.schema().items()} + schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()} # Remove extraneous fields. reduced_schema = schema diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index bfb9f5c4cf..0bb8d47815 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -5,6 +5,7 @@ import pytest from langchain_core.exceptions import OutputParserException from langchain_core.language_models import ParrotFakeChatModel +from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel from langchain_core.prompts.prompt import PromptTemplate from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION @@ -70,3 +71,28 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None: chain = bad_prompt | model | parser with pytest.raises(OutputParserException): chain.invoke({}) + + +# JSON output parser tests +@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) +def test_json_parser_chaining( + pydantic_object: TBaseModel, +) -> None: + prompt = PromptTemplate( + template="""{{ + "temperature": 20, + "f_or_c": "C", + "forecast": "Sunny" + }}""", + input_variables=[], + ) + + model = ParrotFakeChatModel() + + parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore + chain = prompt | model | parser + + res = chain.invoke({}) + assert res["f_or_c"] == "C" + assert res["temperature"] == 20 + assert res["forecast"] == "Sunny"