mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
core[minor]: support pydantic V2 for JSONOutputParser, allow for other sources of JSON schemas (#19716)
This PR adds support for using Pydantic v2 models to generate the schema for `JsonOutputParser` (#19441). It also adds a `json_schema` parameter so that users can pass any JSON schema to validate against, not just one derived from a Pydantic model.
This commit is contained in:
parent
f97de4e275
commit
31e3ecc728
@ -3,15 +3,27 @@ from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Callable, List, Optional, Type
|
||||
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
# Select the model base type(s) this module accepts, depending on the major
# version of the installed pydantic package.
if PYDANTIC_MAJOR_VERSION < 2:
|
||||
# Below major version 2, pydantic.BaseModel is the only base type used.
PydanticBaseModel = pydantic.BaseModel
|
||||
|
||||
else:
|
||||
# Otherwise BaseModel is (re)imported from the pydantic.v1 namespace; this
# rebinds the module-level name BaseModel imported earlier in the file.
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
# Accept both the pydantic.v1 BaseModel and pydantic.BaseModel as bases.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||
|
||||
# Type variable bounded by whichever base type(s) were selected above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
@ -200,11 +212,19 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Optional[Type[BaseModel]] = None
|
||||
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
    """Describe how ``next`` differs from ``prev`` as a JSON Patch.

    Args:
        prev: The previous object, or ``None``.
        next: The current object.

    Returns:
        The ``patch`` attribute of the object returned by
        ``jsonpatch.make_patch(prev, next)``.
    """
    patch_document = jsonpatch.make_patch(prev, next)
    return patch_document.patch
|
||||
|
||||
def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]:
    """Return the JSON schema of the model class ``pydantic_object``.

    When pydantic 2 is installed and the class derives from
    ``pydantic.BaseModel``, the schema comes from ``model_json_schema()``.
    In every other case the class is asked for ``schema()``.

    Args:
        pydantic_object: The model class (not an instance) to describe.

    Returns:
        The schema produced by the model class.
    """
    # Models from the ``pydantic.v1`` namespace need no separate check: they
    # use the same ``schema()`` call as the fallback below.
    if PYDANTIC_MAJOR_VERSION == 2 and issubclass(
        pydantic_object, pydantic.BaseModel
    ):
        return pydantic_object.model_json_schema()
    return pydantic_object.schema()
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
text = result[0].text
|
||||
text = text.strip()
|
||||
@ -228,7 +248,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
return "Return a JSON object."
|
||||
else:
|
||||
# Copy schema to avoid altering original Pydantic schema.
|
||||
schema = {k: v for k, v in self.pydantic_object.schema().items()}
|
||||
schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()}
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = schema
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
@ -70,3 +71,28 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
||||
chain = bad_prompt | model | parser
|
||||
with pytest.raises(OutputParserException):
|
||||
chain.invoke({})
|
||||
|
||||
|
||||
# JSON output parser tests
|
||||
# Check that JsonOutputParser works at the end of a prompt | model | parser
# chain. The test is parametrized over the two model classes listed below.
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||
def test_json_parser_chaining(
|
||||
pydantic_object: TBaseModel,
|
||||
) -> None:
|
||||
# The template is a literal JSON object with no input variables.
# NOTE(review): the doubled braces presumably escape literal braces in the
# template format — confirm against the PromptTemplate documentation.
prompt = PromptTemplate(
|
||||
template="""{{
|
||||
"temperature": 20,
|
||||
"f_or_c": "C",
|
||||
"forecast": "Sunny"
|
||||
}}""",
|
||||
input_variables=[],
|
||||
)
|
||||
|
||||
# NOTE(review): ParrotFakeChatModel presumably echoes the prompt text back,
# so the parser receives the JSON above — TODO confirm.
model = ParrotFakeChatModel()
|
||||
|
||||
# Build the parser from the parametrized model class, then compose the chain.
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore
|
||||
chain = prompt | model | parser
|
||||
|
||||
# Invoke with an empty input mapping, since the prompt declares no variables.
res = chain.invoke({})
|
||||
# The result is indexed by key; each value must match the JSON in the template.
assert res["f_or_c"] == "C"
|
||||
assert res["temperature"] == 20
|
||||
assert res["forecast"] == "Sunny"
|
||||
|
Loading…
Reference in New Issue
Block a user