mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core: pydantic output parser streaming fix (#24415)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
c316361115
commit
6096c80b71
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Generic, List, Type
|
||||
from typing import Generic, List, Optional, Type
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
@ -49,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> TBaseModel:
|
||||
) -> Optional[TBaseModel]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
Args:
|
||||
@ -62,8 +62,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
Returns:
|
||||
The parsed pydantic object.
|
||||
"""
|
||||
json_object = super().parse_result(result)
|
||||
return self._parse_obj(json_object)
|
||||
try:
|
||||
json_object = super().parse_result(result)
|
||||
return self._parse_obj(json_object)
|
||||
except OutputParserException as e:
|
||||
if partial:
|
||||
return None
|
||||
raise e
|
||||
|
||||
def parse(self, text: str) -> TBaseModel:
|
||||
"""Parse the output of an LLM call to a pydantic object.
|
||||
|
@ -1,13 +1,17 @@
|
||||
from typing import Literal
|
||||
"""Test PydanticOutputParser"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
import pytest
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
|
||||
|
||||
V1BaseModel = pydantic.BaseModel
|
||||
@ -96,3 +100,118 @@ def test_json_parser_chaining(
|
||||
assert res["f_or_c"] == "C"
|
||||
assert res["temperature"] == 20
|
||||
assert res["forecast"] == "Sunny"
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
SEARCH = "Search"
|
||||
CREATE = "Create"
|
||||
UPDATE = "Update"
|
||||
DELETE = "Delete"
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
action: Actions = Field(description="Action to be performed")
|
||||
action_input: str = Field(description="Input to be used in the action")
|
||||
additional_fields: Optional[str] = Field(
|
||||
description="Additional fields", default=None
|
||||
)
|
||||
for_new_lines: str = Field(description="To be used to test newlines")
|
||||
|
||||
|
||||
# Prevent pytest from trying to run tests on TestModel
|
||||
TestModel.__test__ = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
DEF_RESULT = """{
|
||||
"action": "Update",
|
||||
"action_input": "The PydanticOutputParser class is powerful",
|
||||
"additional_fields": null,
|
||||
"for_new_lines": "not_escape_newline:\n escape_newline: \\n"
|
||||
}"""
|
||||
|
||||
# action 'update' with a lowercase 'u' to test schema validation failure.
|
||||
DEF_RESULT_FAIL = """{
|
||||
"action": "update",
|
||||
"action_input": "The PydanticOutputParser class is powerful",
|
||||
"additional_fields": null
|
||||
}"""
|
||||
|
||||
DEF_EXPECTED_RESULT = TestModel(
|
||||
action=Actions.UPDATE,
|
||||
action_input="The PydanticOutputParser class is powerful",
|
||||
additional_fields=None,
|
||||
for_new_lines="not_escape_newline:\n escape_newline: \n",
|
||||
)
|
||||
|
||||
|
||||
def test_pydantic_output_parser() -> None:
|
||||
"""Test PydanticOutputParser."""
|
||||
|
||||
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
result = pydantic_parser.parse(DEF_RESULT)
|
||||
print("parse_result:", result) # noqa: T201
|
||||
assert DEF_EXPECTED_RESULT == result
|
||||
assert pydantic_parser.OutputType is TestModel
|
||||
|
||||
|
||||
def test_pydantic_output_parser_fail() -> None:
|
||||
"""Test PydanticOutputParser where completion result fails schema validation."""
|
||||
|
||||
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
try:
|
||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||
except OutputParserException as e:
|
||||
print("parse_result:", e) # noqa: T201
|
||||
assert "Failed to parse TestModel from completion" in str(e)
|
||||
else:
|
||||
assert False, "Expected OutputParserException"
|
||||
|
||||
|
||||
def test_pydantic_output_parser_type_inference() -> None:
|
||||
"""Test pydantic output parser type inference."""
|
||||
|
||||
class SampleModel(BaseModel):
|
||||
foo: int
|
||||
bar: str
|
||||
|
||||
# Ignoring mypy error that appears in python 3.8, but not 3.11.
|
||||
# This seems to be functionally correct, so we'll ignore the error.
|
||||
pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) # type: ignore
|
||||
schema = pydantic_parser.get_output_schema().schema()
|
||||
|
||||
assert schema == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"foo": {"title": "Foo", "type": "integer"},
|
||||
},
|
||||
"required": ["foo", "bar"],
|
||||
"title": "SampleModel",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_format_instructions_preserves_language() -> None:
|
||||
"""Test format instructions does not attempt to encode into ascii."""
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
description = (
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
)
|
||||
|
||||
class Foo(BaseModel):
|
||||
hello: str = Field(
|
||||
description=(
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
)
|
||||
)
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=Foo) # type: ignore
|
||||
assert description in parser.get_format_instructions()
|
||||
|
@ -1,123 +0,0 @@
|
||||
"""Test PydanticOutputParser"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
SEARCH = "Search"
|
||||
CREATE = "Create"
|
||||
UPDATE = "Update"
|
||||
DELETE = "Delete"
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
action: Actions = Field(description="Action to be performed")
|
||||
action_input: str = Field(description="Input to be used in the action")
|
||||
additional_fields: Optional[str] = Field(
|
||||
description="Additional fields", default=None
|
||||
)
|
||||
for_new_lines: str = Field(description="To be used to test newlines")
|
||||
|
||||
|
||||
# Prevent pytest from trying to run tests on TestModel
|
||||
TestModel.__test__ = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
DEF_RESULT = """{
|
||||
"action": "Update",
|
||||
"action_input": "The PydanticOutputParser class is powerful",
|
||||
"additional_fields": null,
|
||||
"for_new_lines": "not_escape_newline:\n escape_newline: \\n"
|
||||
}"""
|
||||
|
||||
# action 'update' with a lowercase 'u' to test schema validation failure.
|
||||
DEF_RESULT_FAIL = """{
|
||||
"action": "update",
|
||||
"action_input": "The PydanticOutputParser class is powerful",
|
||||
"additional_fields": null
|
||||
}"""
|
||||
|
||||
DEF_EXPECTED_RESULT = TestModel(
|
||||
action=Actions.UPDATE,
|
||||
action_input="The PydanticOutputParser class is powerful",
|
||||
additional_fields=None,
|
||||
for_new_lines="not_escape_newline:\n escape_newline: \n",
|
||||
)
|
||||
|
||||
|
||||
def test_pydantic_output_parser() -> None:
|
||||
"""Test PydanticOutputParser."""
|
||||
|
||||
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
result = pydantic_parser.parse(DEF_RESULT)
|
||||
print("parse_result:", result) # noqa: T201
|
||||
assert DEF_EXPECTED_RESULT == result
|
||||
assert pydantic_parser.OutputType is TestModel
|
||||
|
||||
|
||||
def test_pydantic_output_parser_fail() -> None:
|
||||
"""Test PydanticOutputParser where completion result fails schema validation."""
|
||||
|
||||
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
try:
|
||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||
except OutputParserException as e:
|
||||
print("parse_result:", e) # noqa: T201
|
||||
assert "Failed to parse TestModel from completion" in str(e)
|
||||
else:
|
||||
assert False, "Expected OutputParserException"
|
||||
|
||||
|
||||
def test_pydantic_output_parser_type_inference() -> None:
|
||||
"""Test pydantic output parser type inference."""
|
||||
|
||||
class SampleModel(BaseModel):
|
||||
foo: int
|
||||
bar: str
|
||||
|
||||
# Ignoring mypy error that appears in python 3.8, but not 3.11.
|
||||
# This seems to be functionally correct, so we'll ignore the error.
|
||||
pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) # type: ignore
|
||||
schema = pydantic_parser.get_output_schema().schema()
|
||||
|
||||
assert schema == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"foo": {"title": "Foo", "type": "integer"},
|
||||
},
|
||||
"required": ["foo", "bar"],
|
||||
"title": "SampleModel",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_format_instructions_preserves_language() -> None:
|
||||
"""Test format instructions does not attempt to encode into ascii."""
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
description = (
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
)
|
||||
|
||||
class Foo(BaseModel):
|
||||
hello: str = Field(
|
||||
description=(
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
)
|
||||
)
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=Foo) # type: ignore
|
||||
assert description in parser.get_format_instructions()
|
Loading…
Reference in New Issue
Block a user