openai[patch]: ChatOpenAI.with_structured_output json_schema support (#25123)

Bagatur 2024-08-07 08:09:07 -07:00 committed by GitHub
parent 0ba125c3cd
commit 09fbce13c5
6 changed files with 895 additions and 603 deletions
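For orientation, a minimal usage sketch of what this change enables (model name and schema are illustrative and mirror the integration test further down; this is not part of the diff itself):

from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

class Joke(BaseModel):
    """Joke to tell the user."""
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")

llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
# New in this commit: method="json_schema" routes through OpenAI's
# Structured Outputs API; strict defaults to True for this method.
structured_llm = llm.with_structured_output(Joke, method="json_schema")
joke = structured_llm.invoke("Tell me a joke about cats.")  # -> Joke instance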


@@ -7,6 +7,7 @@ import json
import logging
import os
import sys
import warnings
from io import BytesIO
from math import ceil
from operator import itemgetter
@@ -27,7 +28,6 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from urllib.parse import urlparse
@@ -74,7 +74,7 @@ from langchain_core.output_parsers.openai_tools import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@@ -86,7 +86,11 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.pydantic import (
PydanticBaseModel,
TypeBaseModel,
is_basemodel_subclass,
)
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@@ -298,6 +302,8 @@ class _AllReturnType(TypedDict):
class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
root_client: Any = Field(default=None, exclude=True) #: :meta private:
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
@@ -445,9 +451,8 @@ class BaseChatOpenAI(BaseChatModel):
) from e
values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
values["client"] = values["root_client"].chat.completions
if not values.get("async_client"):
if values["openai_proxy"] and not values["http_async_client"]:
try:
@@ -461,10 +466,10 @@ class BaseChatOpenAI(BaseChatModel):
proxy=values["openai_proxy"]
)
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
values["root_async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
)
values["async_client"] = values["root_async_client"].chat.completions
return values
@property
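The new root_client/root_async_client fields keep a reference to the full OpenAI SDK client rather than only its chat.completions sub-client, which the structured-output path added below needs. Roughly (a sketch of the relationship, not the validator code itself; assumes OPENAI_API_KEY is set):

import openai

root_client = openai.OpenAI()          # full SDK client
client = root_client.chat.completions  # what `client` alone exposed before
# The json_schema path introduced below calls:
#   root_client.beta.chat.completions.parse(...)
# which is not reachable from the chat.completions sub-client.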
@@ -525,13 +530,32 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload and is_basemodel_subclass(
payload["response_format"]
):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
chat_result = self._generate(
messages, stop, run_manager=run_manager, **kwargs
)
msg = chat_result.generations[0].message
yield ChatGenerationChunk(
message=AIMessageChunk(
**msg.dict(exclude={"type", "additional_kwargs"}),
# preserve the "parsed" Pydantic object without converting to dict
additional_kwargs=msg.additional_kwargs,
),
generation_info=chat_result.generations[0].generation_info,
)
return
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
base_generation_info = {}
with response:
is_first_chunk = True
for chunk in response:
@@ -594,13 +618,21 @@ class BaseChatOpenAI(BaseChatModel):
)
return generate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
if self.include_response_headers:
generation_info = None
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response = self.root_client.beta.chat.completions.parse(**payload)
elif self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
generation_info = None
return self._create_chat_result(response, generation_info)
def _get_request_payload(
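When a response_format is present, _generate now delegates to the SDK's beta parse endpoint via root_client. Stripped of the surrounding plumbing, the underlying SDK call looks roughly like this (requires openai>=1.40, per the pyproject bump below; model and schema are illustrative):

import openai
from pydantic import BaseModel

class Joke(BaseModel):
    setup: str
    punchline: str

client = openai.OpenAI()  # assumes OPENAI_API_KEY is set
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "Tell me a joke about cats."}],
    response_format=Joke,
)
message = completion.choices[0].message
message.parsed   # Joke instance, or None if the model refused
message.refusal  # refusal string, or None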
@@ -625,18 +657,19 @@ class BaseChatOpenAI(BaseChatModel):
generation_info: Optional[Dict] = None,
) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.model_dump()
response_dict = (
response if isinstance(response, dict) else response.model_dump()
)
# Sometimes the AI model call returns an error; raise it here.
# Otherwise, iterating over response_dict["choices"] below would raise a
# "TypeError: 'NoneType' object is not iterable" that masks the true
# error, because response_dict["choices"] is None.
if response.get("error"):
raise ValueError(response.get("error"))
if response_dict.get("error"):
raise ValueError(response_dict.get("error"))
token_usage = response.get("usage", {})
for res in response["choices"]:
token_usage = response_dict.get("usage", {})
for res in response_dict["choices"]:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
message.usage_metadata = {
@@ -656,9 +689,19 @@ class BaseChatOpenAI(BaseChatModel):
generations.append(gen)
llm_output = {
"token_usage": token_usage,
"model_name": response.get("model", self.model_name),
"system_fingerprint": response.get("system_fingerprint", ""),
"model_name": response_dict.get("model", self.model_name),
"system_fingerprint": response_dict.get("system_fingerprint", ""),
}
if isinstance(response, openai.BaseModel) and getattr(
response, "choices", None
):
message = response.choices[0].message # type: ignore[attr-defined]
if hasattr(message, "parsed"):
generations[0].message.additional_kwargs["parsed"] = message.parsed
if hasattr(message, "refusal"):
generations[0].message.additional_kwargs["refusal"] = message.refusal
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
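_create_chat_result now copies the SDK's parsed object and any refusal onto the message's additional_kwargs; that is what the new output parser at the bottom of this file reads. A sketch of what this looks like on the raw message (reusing the llm/Joke names from the sketch above):

bound = llm.bind(response_format=Joke)
ai_msg = bound.invoke("Tell me a joke about cats.")
parsed = ai_msg.additional_kwargs.get("parsed")    # Joke instance, or None
refusal = ai_msg.additional_kwargs.get("refusal")  # refusal string, or None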
@@ -671,13 +714,31 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload and is_basemodel_subclass(
payload["response_format"]
):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
chat_result = await self._agenerate(
messages, stop, run_manager=run_manager, **kwargs
)
msg = chat_result.generations[0].message
yield ChatGenerationChunk(
message=AIMessageChunk(
**msg.dict(exclude={"type", "additional_kwargs"}),
# preserve the "parsed" Pydantic object without converting to dict
additional_kwargs=msg.additional_kwargs,
),
generation_info=chat_result.generations[0].generation_info,
)
return
if self.include_response_headers:
raw_response = self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
base_generation_info = {}
async with response:
is_first_chunk = True
async for chunk in response:
@@ -745,13 +806,23 @@ class BaseChatOpenAI(BaseChatModel):
)
return await agenerate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
if self.include_response_headers:
generation_info = None
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response = await self.root_async_client.beta.chat.completions.parse(
**payload
)
elif self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
generation_info = None
return await run_in_executor(
None, self._create_chat_result, response, generation_info
)
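The async path mirrors the sync one through root_async_client; for example (a sketch, same illustrative names as above):

import asyncio

async def main() -> None:
    structured_llm = llm.with_structured_output(Joke, method="json_schema")
    joke = await structured_llm.ainvoke("Tell me a joke about cats.")
    print(joke.setup, joke.punchline)

asyncio.run(main())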
@@ -1028,34 +1099,13 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
# TODO: Fix typing.
@overload # type: ignore[override]
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[True] = True,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _AllReturnType]: ...
@overload
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[False] = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
@@ -1065,10 +1115,12 @@ class BaseChatOpenAI(BaseChatModel):
.. versionchanged:: 0.1.21
Support for ``strict`` argument added.
Support for ``method`` = "json_schema" added.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class (support added in 0.1.20),
@@ -1085,12 +1137,36 @@ class BaseChatOpenAI(BaseChatModel):
Added support for TypedDict class.
method:
The method for steering model generation, one of "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then OpenAI's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema into the model call.
The method for steering model generation, one of:
- "function_calling":
Uses OpenAI's tool-calling (formerly called function calling)
API: https://platform.openai.com/docs/guides/function-calling
- "json_schema":
Uses OpenAI's Structured Output API:
https://platform.openai.com/docs/guides/structured-outputs.
Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later
models.
- "json_mode":
Uses OpenAI's JSON mode. Note that if using JSON mode then you
must include instructions for formatting the output into the
desired schema in the model call:
https://platform.openai.com/docs/guides/structured-outputs/json-mode
Learn more about the differences between the methods and which models
support which methods here:
- https://platform.openai.com/docs/guides/structured-outputs/structured-outputs-vs-json-mode
- https://platform.openai.com/docs/guides/structured-outputs/function-calling-vs-response-format
.. versionchanged:: 0.1.21
Added support for "json_schema".
.. note:: Planned breaking change in version `0.2.0`
The default for ``method`` will change from "function_calling" to
"json_schema".
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
@@ -1098,14 +1174,20 @@ class BaseChatOpenAI(BaseChatModel):
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
strict: If True and ``method`` = "function_calling", model output is
guaranteed to exactly match the schema
If True, the input schema will also be
validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
If False, input schema will not be validated and model output will not
be validated.
If None, ``strict`` argument will not be passed to the model.
strict:
- True:
Model output is guaranteed to exactly match the schema.
The input schema will also be validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
- False:
Input schema will not be validated and model output will not be
validated.
- None:
``strict`` argument will not be passed to the model.
If ``method`` is "json_schema" defaults to True. If ``method`` is
"function_calling" or "json_mode" defaults to None. Can only be
non-null if ``method`` is "function_calling" or "json_schema".
.. versionadded:: 0.1.21
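To make the include_raw contract above concrete, a hedged sketch (names reuse the earlier Joke example):

structured_llm = llm.with_structured_output(
    Joke, method="json_schema", include_raw=True
)
out = structured_llm.invoke("Tell me a joke about cats.")
out["raw"]            # the underlying AIMessage
out["parsed"]         # Joke instance, or None if parsing failed
out["parsing_error"]  # the exception raised during parsing, or None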
@@ -1124,9 +1206,10 @@ class BaseChatOpenAI(BaseChatModel):
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]
- "raw": BaseMessage
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- "parsing_error": Optional[BaseException]
Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:
.. note:: Valid schemas when using ``strict`` = True
@@ -1305,15 +1388,15 @@ class BaseChatOpenAI(BaseChatModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
if strict is not None and method != "function_calling":
if strict is not None and method == "json_mode":
raise ValueError(
"Argument `strict` is only supported for `method`='function_calling'"
"Argument `strict` is not supported with `method`='json_mode'"
)
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -1339,6 +1422,20 @@ class BaseChatOpenAI(BaseChatModel):
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
strict = strict if strict is not None else True
response_format = _convert_to_openai_response_format(schema, strict=strict)
llm = self.bind(response_format=response_format)
output_parser = (
cast(Runnable, _oai_structured_outputs_parser)
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
@@ -1975,3 +2072,40 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
height = (width * 768) // height
width = 768
return width, height
def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], strict: bool
) -> Union[Dict, TypeBaseModel]:
if isinstance(schema, type) and is_basemodel_subclass(schema):
return schema
else:
function = convert_to_openai_function(schema, strict=strict)
function["schema"] = function.pop("parameters")
return {"type": "json_schema", "json_schema": function}
@chain
def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
if ai_msg.additional_kwargs.get("parsed"):
return ai_msg.additional_kwargs["parsed"]
elif ai_msg.additional_kwargs.get("refusal"):
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
else:
raise ValueError(
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
"field."
)
class OpenAIRefusalError(Exception):
"""Error raised when OpenAI Structured Outputs API returns a refusal.
When using OpenAI's Structured Outputs API with user-generated input, the model
may occasionally refuse to fulfill the request for safety reasons.
See here for more on refusals:
https://platform.openai.com/docs/guides/structured-outputs/refusals
.. versionadded:: 0.1.21
"""


@@ -1527,4 +1527,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "23d99a41f0cff5bf1869e8e18ac953a9802b0d1912eedcecca624650c1ff3af6"
content-hash = "a08bed7f2e62b3f6c7fc52a31c2529b44d4e5adcc55aba5047be027596fdb31f"


@@ -24,7 +24,7 @@ ignore_missing_imports = true
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = { version = "^0.2.29rc1", allow-prereleases=true }
openai = "^1.32.0"
openai = "^1.40.0"
tiktoken = ">=0.7,<1"
[tool.ruff.lint]


@@ -1,7 +1,7 @@
"""Test ChatOpenAI chat model."""
import base64
from typing import Any, AsyncIterator, List, Optional, cast
from typing import Any, AsyncIterator, List, Literal, Optional, cast
import httpx
import openai
@@ -796,13 +796,21 @@ def test_tool_calling_strict() -> None:
next(model_with_invalid_tool_schema.stream(query))
def test_structured_output_strict() -> None:
@pytest.mark.parametrize(
("model", "method", "strict"),
[("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)],
)
def test_structured_output_strict(
model: str,
method: Literal["function_calling", "json_schema"],
strict: Optional[bool],
) -> None:
"""Test to verify structured output with strict=True."""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
model = ChatOpenAI(model="gpt-4o", temperature=0)
llm = ChatOpenAI(model=model, temperature=0)
class Joke(BaseModelProper):
"""Joke to tell user."""
@@ -814,7 +822,7 @@ def test_structured_output_strict() -> None:
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type]
chat = llm.with_structured_output(Joke, method=method, strict=strict)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
@@ -822,7 +830,9 @@ def test_structured_output_strict() -> None:
assert isinstance(chunk, Joke)
# Schema
chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
chat = llm.with_structured_output(
Joke.model_json_schema(), method=method, strict=strict
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
@@ -831,3 +841,27 @@ def test_structured_output_strict() -> None:
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
# Invalid schema with optional fields:
class InvalidJoke(BaseModelProper):
"""Joke to tell user."""
setup: str = FieldProper(description="question to set up a joke")
# Invalid field, can't have default value.
punchline: str = FieldProper(
default="foo", description="answer to resolve the joke"
)
chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))
chat = llm.with_structured_output(
InvalidJoke.model_json_schema(), method=method, strict=strict
)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))
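The InvalidJoke cases fail because OpenAI's strict structured outputs require every property to appear in the schema's "required" list, and a field with a default is emitted as optional; roughly:

InvalidJoke.model_json_schema()
# -> {..., "properties": {"setup": {...},
#                         "punchline": {"default": "foo", ...}},
#     "required": ["setup"]}
# "punchline" is missing from "required", so the API rejects the
# schema with a 400 (openai.BadRequestError) when strict=True.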


@@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper."""
import json
from typing import Any, List, Type, Union
from typing import Any, Dict, List, Literal, Optional, Type, Union
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -343,17 +343,32 @@ class MakeASandwich(BaseModel):
None,
],
)
def test_bind_tools_tool_choice(tool_choice: Any) -> None:
@pytest.mark.parametrize("strict", [True, False, None])
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
llm.bind_tools(
tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice, strict=strict
)
@pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
@pytest.mark.parametrize("method", ["json_schema", "function_calling", "json_mode"])
@pytest.mark.parametrize("include_raw", [True, False])
@pytest.mark.parametrize("strict", [True, False, None])
def test_with_structured_output(
schema: Union[Type, Dict[str, Any], None],
method: Literal["function_calling", "json_mode", "json_schema"],
include_raw: bool,
strict: Optional[bool],
) -> None:
"""Test passing in manually construct tool call message."""
if method == "json_mode":
strict = None
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.with_structured_output(schema)
llm.with_structured_output(
schema, method=method, strict=strict, include_raw=include_raw
)
def test_get_num_tokens_from_messages() -> None:

poetry.lock (generated): file diff suppressed because it is too large.