docs: Document and test PydanticOutputFunctionsParser (#15759)

This PR adds documentation and testing to
`PydanticOutputFunctionsParser(OutputFunctionsParser)`.
pull/16174/head
Eugene Yurtsev 5 months ago committed by GitHub
parent 3502a407d9
commit 5d8c147332
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -136,10 +136,52 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object."""
"""Parse an output as a pydantic object.
This parser is used to parse the output of a ChatModel that uses
OpenAI function format to invoke functions.
The parser extracts the function call invocation and matches
them to the pydantic schema provided.
An exception will be raised if the function call does not match
the provided schema.
Example:
... code-block:: python
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Cookie(BaseModel):
name: str
age: int
class Dog(BaseModel):
species: str
# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
"""
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
"""The pydantic schema to parse the output with."""
"""The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
determine which schema to use.
"""
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:

@ -1,3 +1,4 @@
import json
from typing import Any, Dict
import pytest
@ -7,7 +8,9 @@ from langchain_core.outputs import ChatGeneration
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.pydantic_v1 import BaseModel
def test_json_output_function_parser() -> None:
@ -134,3 +137,61 @@ def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
with pytest.raises(OutputParserException):
JsonOutputFunctionsParser().parse_result([chat_generation])
def test_pydantic_output_functions_parser() -> None:
"""Test pydantic output functions parser."""
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "function_name",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Model(BaseModel):
"""Test model."""
name: str
age: int
# Full output
parser = PydanticOutputFunctionsParser(pydantic_schema=Model)
result = parser.parse_result([chat_generation])
assert result == Model(name="value", age=10)
def test_pydantic_output_functions_parser_multiple_schemas() -> None:
"""Test that the parser works if providing multiple pydantic schemas."""
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Cookie(BaseModel):
"""Test model."""
name: str
age: int
class Dog(BaseModel):
"""Test model."""
species: str
# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
assert result == Cookie(name="value", age=10)

Loading…
Cancel
Save