core[minor]: support pydantic V2 for JSONOutputParser, allow for other sources of JSON schemas (#19716)

This PR supports using Pydantic v2 objects to generate the schema for
the JSONOutputParser (#19441). This also adds a `json_schema` parameter
to allow users to pass any JSON schema to validate with, not just
pydantic.
This commit is contained in:
Jan Nissen 2024-04-04 10:57:47 -04:00 committed by GitHub
parent f97de4e275
commit 31e3ecc728
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 4 deletions

View File

@ -3,15 +3,27 @@ from __future__ import annotations
import json
import re
from json import JSONDecodeError
from typing import Any, Callable, List, Optional, Type
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
if PYDANTIC_MAJOR_VERSION < 2:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel # pydantic: ignore
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def _replace_new_line(match: re.Match[str]) -> str:
@ -200,11 +212,19 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Optional[Type[BaseModel]] = None
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]:
if PYDANTIC_MAJOR_VERSION == 2:
if issubclass(pydantic_object, pydantic.BaseModel):
return pydantic_object.model_json_schema()
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
return pydantic_object.schema()
return pydantic_object.schema()
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
text = result[0].text
text = text.strip()
@ -228,7 +248,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return "Return a JSON object."
else:
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()}
schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()}
# Remove extraneous fields.
reduced_schema = schema

View File

@ -5,6 +5,7 @@ import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
@ -70,3 +71,28 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
chain = bad_prompt | model | parser
with pytest.raises(OutputParserException):
chain.invoke({})
# JSON output parser tests
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
def test_json_parser_chaining(
pydantic_object: TBaseModel,
) -> None:
prompt = PromptTemplate(
template="""{{
"temperature": 20,
"f_or_c": "C",
"forecast": "Sunny"
}}""",
input_variables=[],
)
model = ParrotFakeChatModel()
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore
chain = prompt | model | parser
res = chain.invoke({})
assert res["f_or_c"] == "C"
assert res["temperature"] == 20
assert res["forecast"] == "Sunny"