forked from Archives/langchain
59d054308c
Currently, the output type of a number of OutputParsers' `parse` methods is `Any` when it can in fact be inferred. This PR makes BaseOutputParser use a generic type and fixes the output types of the following parsers:
- `PydanticOutputParser`
- `OutputFixingParser`
- `RetryOutputParser`
- `RetryWithErrorOutputParser`

The output type of the `StructuredOutputParser` is corrected from `BaseModel` to `Any`, since the parser provides no type guarantees. Fixes issue #2715.
72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
"""Test PydanticOutputParser"""
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from langchain.output_parsers.pydantic import PydanticOutputParser
|
|
from langchain.schema import OutputParserException
|
|
|
|
|
|
class Actions(Enum):
    """Closed set of actions a parsed completion may request.

    Members are looked up by their capitalized string value; Enum value
    lookup is case-sensitive, so e.g. ``"update"`` is not a valid value.
    """

    # Declaration order is preserved by Enum iteration.
    SEARCH = "Search"
    CREATE = "Create"
    UPDATE = "Update"
    DELETE = "Delete"
|
|
|
|
|
|
class TestModel(BaseModel):
    """Pydantic schema that the parser under test should produce.

    ``test_pydantic_output_parser_fail`` expects this class name to appear
    in the parser's error message, so the name is kept as-is.
    """

    # The "Test" prefix makes pytest try to collect this class as a test
    # case; since BaseModel defines __init__, pytest would emit a
    # PytestCollectionWarning. Opt out explicitly. Pydantic does not treat
    # dunder class attributes as fields, so the schema is unchanged.
    __test__ = False

    action: Actions = Field(description="Action to be performed")
    action_input: str = Field(description="Input to be used in the action")
    # Optional: may be omitted or given as JSON null.
    additional_fields: Optional[str] = Field(
        description="Additional fields", default=None
    )
|
|
|
|
|
|
# Well-formed completion: "Update" is a valid Actions value, action_input is
# a string, and additional_fields is null (allowed by its Optional type).
DEF_RESULT = (
    "{\n"
    '    "action": "Update",\n'
    '    "action_input": "The PydanticOutputParser class is powerful",\n'
    '    "additional_fields": null\n'
    "}"
)
|
|
|
|
# Same payload as DEF_RESULT except the action is "update" with a lowercase
# 'u'. That string is not an Actions value, so schema validation must fail.
DEF_RESULT_FAIL = (
    "{\n"
    '    "action": "update",\n'
    '    "action_input": "The PydanticOutputParser class is powerful",\n'
    '    "additional_fields": null\n'
    "}"
)
|
|
|
|
# Model instance that parsing DEF_RESULT is expected to yield.
# Keyword order is irrelevant: fields are bound by name.
DEF_EXPECTED_RESULT = TestModel(
    action_input="The PydanticOutputParser class is powerful",
    additional_fields=None,
    action=Actions.UPDATE,
)
|
|
|
|
|
|
def test_pydantic_output_parser() -> None:
    """Parse a schema-valid completion and compare it to the expected model."""
    parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
        pydantic_object=TestModel
    )

    parsed = parser.parse(DEF_RESULT)
    # Echoed so the parsed object is visible when running with ``pytest -s``.
    print("parse_result:", parsed)

    assert parsed == DEF_EXPECTED_RESULT
|
|
|
|
|
|
def test_pydantic_output_parser_fail() -> None:
    """Test PydanticOutputParser where completion result fails schema validation.

    ``DEF_RESULT_FAIL`` carries a lowercase action value, so parsing must
    raise ``OutputParserException`` with a message naming ``TestModel``.
    """
    pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
        pydantic_object=TestModel
    )

    try:
        pydantic_parser.parse(DEF_RESULT_FAIL)
    except OutputParserException as e:
        # Echoed so the error text is visible when running with ``pytest -s``.
        print("parse_result:", e)
        assert "Failed to parse TestModel from completion" in str(e)
    else:
        # Raise explicitly instead of ``assert False``: assert statements are
        # stripped under ``python -O``, which would let a missing exception
        # pass silently.
        raise AssertionError("Expected OutputParserException")
|