core[minor], openai[minor], langchain[patch]: BaseLanguageModel.with_structured_output (#17302)

```python
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI


class Foo(BaseModel):
    bar: str


structured_llm = ChatOpenAI().with_structured_output(Foo)
```
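As a quick illustration of the resulting runnable (a sketch; the prompt and the returned value are illustrative, not from this PR):

```python
# structured_llm accepts ordinary chat-model input and returns a validated
# Foo instance, because the schema passed above is a Pydantic class.
result = structured_llm.invoke("Name a bar in Dublin")
assert isinstance(result, Foo)  # e.g. Foo(bar="The Temple Bar")
```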

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Bagatur committed via GitHub
parent f685d2f50c
commit b5f8cf9509

@@ -5,17 +5,19 @@ from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Type,
    TypeVar,
    Union,
)

from typing_extensions import TypeAlias

from langchain_core._api import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.messages import (
    AnyMessage,
    BaseMessage,
@@ -23,6 +25,7 @@ from langchain_core.messages import (
    get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -155,6 +158,13 @@ class BaseLanguageModel(
        prompt and additional model provider-specific output.
        """

    @beta()
    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Implement this if there is a way of steering the model to generate responses that match a given schema."""  # noqa: E501
        raise NotImplementedError()

    @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
    @abstractmethod
    def predict(

@@ -24,6 +24,7 @@ from langchain_core.output_parsers.list import (
    MarkdownListOutputParser,
    NumberedListOutputParser,
)
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.transform import (
    BaseCumulativeTransformOutputParser,
@@ -45,4 +46,5 @@ __all__ = [
    "SimpleJsonOutputParser",
    "XMLOutputParser",
    "JsonOutputParser",
    "PydanticOutputParser",
]

@@ -15,15 +15,17 @@ from typing import (
from typing_extensions import get_args

from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor

if TYPE_CHECKING:
    from langchain_core.prompt_values import PromptValue

T = TypeVar("T")
OutputParserLike = Runnable[LanguageModelOutput, T]


class BaseLLMOutputParser(Generic[T], ABC):
@@ -57,7 +59,7 @@ class BaseLLMOutputParser(Generic[T], ABC):

class BaseGenerationOutputParser(
    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
    """Base class to parse the output of an LLM call."""

@@ -116,7 +118,7 @@ class BaseGenerationOutputParser(

class BaseOutputParser(
    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
    """Base class to parse the output of an LLM call.
@@ -0,0 +1,62 @@
import json
from typing import Any, List, Type

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError


class PydanticOutputParser(JsonOutputParser):
    """Parse an output using a pydantic model."""

    pydantic_object: Type[BaseModel]
    """The pydantic model to parse.

    Attention: To avoid potential compatibility issues, it's recommended to use
        pydantic <2 or leverage the v1 namespace in pydantic >= 2.
    """

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        json_object = super().parse_result(result)
        try:
            return self.pydantic_object.parse_obj(json_object)
        except ValidationError as e:
            name = self.pydantic_object.__name__
            msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
            raise OutputParserException(msg, llm_output=json_object)

    def get_format_instructions(self) -> str:
        # Copy schema to avoid altering original Pydantic schema.
        schema = {k: v for k, v in self.pydantic_object.schema().items()}
        # Remove extraneous fields.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema)

        return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @property
    def _type(self) -> str:
        return "pydantic"

    @property
    def OutputType(self) -> Type[BaseModel]:
        """Return the pydantic model."""
        return self.pydantic_object


_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.

Here is the output schema:
```
{schema}
```"""  # noqa: E501
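Since `PydanticOutputParser` is now exported from `langchain_core.output_parsers` (see the `__all__` change above), it can also be used standalone. A minimal sketch, assuming the import path from this diff; the `Generation` text is illustrative:

```python
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel


class Person(BaseModel):
    name: str
    age: int


parser = PydanticOutputParser(pydantic_object=Person)

# get_format_instructions() renders the reduced JSON schema for the prompt.
print(parser.get_format_instructions())

# parse_result() parses the model's JSON and validates it against the class.
result = parser.parse_result([Generation(text='{"name": "Ada", "age": 36}')])
assert result == Person(name="Ada", age=36)
```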

@@ -14,6 +14,7 @@ EXPECTED_ALL = [
    "SimpleJsonOutputParser",
    "XMLOutputParser",
    "JsonOutputParser",
    "PydanticOutputParser",
]

@@ -1,53 +1,3 @@
import json
from typing import Any, List, Type

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError

from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS


class PydanticOutputParser(JsonOutputParser):
    """Parse an output using a pydantic model."""

    pydantic_object: Type[BaseModel]
    """The pydantic model to parse.

    Attention: To avoid potential compatibility issues, it's recommended to use
        pydantic <2 or leverage the v1 namespace in pydantic >= 2.
    """

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        json_object = super().parse_result(result)
        try:
            return self.pydantic_object.parse_obj(json_object)
        except ValidationError as e:
            name = self.pydantic_object.__name__
            msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
            raise OutputParserException(msg, llm_output=json_object)

    def get_format_instructions(self) -> str:
        # Copy schema to avoid altering original Pydantic schema.
        schema = {k: v for k, v in self.pydantic_object.schema().items()}
        # Remove extraneous fields.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema)

        return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @property
    def _type(self) -> str:
        return "pydantic"

    @property
    def OutputType(self) -> Type[BaseModel]:
        """Return the pydantic model."""
        return self.pydantic_object


__all__ = ["PydanticOutputParser"]

@@ -5,6 +5,7 @@ from __future__ import annotations

import logging
import os
import sys
from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
@@ -19,12 +20,15 @@ from typing import (
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    cast,
    overload,
)

import openai
import tiktoken
from langchain_core._api import beta
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
@@ -51,9 +55,14 @@ from langchain_core.messages import (
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
    convert_to_secret_str,
@@ -66,6 +75,11 @@ from langchain_core.utils.function_calling import (
)
from langchain_core.utils.utils import build_extra_kwargs

from langchain_openai.output_parsers import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)

logger = logging.getLogger(__name__)

@@ -189,6 +203,17 @@ class _FunctionCall(TypedDict):
    name: str


_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
_DictOrPydantic = Union[Dict, _BM]


class _AllReturnType(TypedDict):
    raw: BaseMessage
    parsed: Optional[_DictOrPydantic]
    parsing_error: Optional[BaseException]


class ChatOpenAI(BaseChatModel):
    """`OpenAI` Chat large language models API.
@@ -673,7 +698,7 @@ class ChatOpenAI(BaseChatModel):
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"]]] = None,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.
@@ -695,21 +720,215 @@ class ChatOpenAI(BaseChatModel):
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None:
            if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
                tool_choice = {"type": "function", "function": {"name": tool_choice}}
            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
        if tool_choice is not None and tool_choice:
            if len(formatted_tools) != 1:
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, dict) and (
                formatted_tools[0]["function"]["name"]
                != tool_choice["function"]["name"]
            ):
            if isinstance(tool_choice, str):
                if tool_choice not in ("auto", "none"):
                    tool_choice = {
                        "type": "function",
                        "function": {"name": tool_choice},
                    }
            elif isinstance(tool_choice, bool):
                tool_choice = formatted_tools[0]
            elif isinstance(tool_choice, dict):
                if (
                    formatted_tools[0]["function"]["name"]
                    != tool_choice["function"]["name"]
                ):
                    raise ValueError(
                        f"Tool choice {tool_choice} was specified, but the only "
                        f"provided tool was {formatted_tools[0]['function']['name']}."
                    )
            else:
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tool was {formatted_tools[0]['function']['name']}."
                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
                    f"Received: {tool_choice}"
                )
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
    @overload
    def with_structured_output(
        self,
        schema: _DictOrPydanticClass,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        return_type: Literal["all"] = "all",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _AllReturnType]:
        ...

    @overload
    def with_structured_output(
        self,
        schema: _DictOrPydanticClass,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        return_type: Literal["parsed"] = "parsed",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        ...

    @beta()
    def with_structured_output(
        self,
        schema: _DictOrPydanticClass,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        return_type: Literal["parsed", "all"] = "parsed",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec.
method: The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then OpenAI's JSON mode will be
used.
return_type: The wrapped model's return type, either "parsed" or "all". If
"parsed" then only the parsed structured output is returned. If an
error occurs during model output parsing it will be raised. If "all"
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns as output:
If return_type == "all" then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If return_type == "parsed" then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Function-calling, Pydantic schema (method="function_calling", return_type="parsed"):
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Function-calling, Pydantic schema (method="function_calling", return_type="all"):
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, return_type="all")
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example: Function-calling, dict schema (method="function_calling", return_type="parsed"):
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            llm = self.bind_tools([schema], tool_choice=True)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema], first_tool_only=True
                )
            else:
                key_name = convert_to_openai_tool(schema)["function"]["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )

        if return_type == "parsed":
            return llm | output_parser
        elif return_type == "all":
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            raise ValueError(
                f"Unrecognized return_type argument. Expected one of 'parsed' or "
                f"'all'. Received: '{return_type}'"
            )
def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)
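The docstring examples above cover method="function_calling" only; here is a comparable sketch for method="json_mode" (illustrative output, not from the PR). Note that with JSON mode the schema is enforced only by the parser, so the prompt itself must ask for a JSON object with the expected keys:

```python
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI


class AnswerWithJustification(BaseModel):
    """An answer to the user question along with justification for the answer."""

    answer: str
    justification: str


llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, method="json_mode")

structured_llm.invoke(
    "Answer the following question. Respond only with a JSON object "
    "with keys 'answer' and 'justification'.\n\n"
    "What weighs more, a pound of bricks or a pound of feathers?"
)
# -> AnswerWithJustification(answer='...', justification='...')  (illustrative)
```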

@@ -0,0 +1,11 @@
from langchain_openai.output_parsers.tools import (
    JsonOutputKeyToolsParser,
    JsonOutputToolsParser,
    PydanticToolsParser,
)

__all__ = [
    "JsonOutputToolsParser",
    "JsonOutputKeyToolsParser",
    "PydanticToolsParser",
]

@@ -0,0 +1,123 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, List, Type

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel


class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
    """Parse tools from OpenAI response."""

    strict: bool = False
    """Whether to allow non-JSON-compliant strings.

    See: https://docs.python.org/3/library/json.html#encoders-and-decoders

    Useful when the parsed output may include unicode characters or new lines.
    """
    return_id: bool = False
    """Whether to return the tool call id."""
    first_tool_only: bool = False
    """Whether to return only the first tool call.

    If False, the result will be a list of tool calls, or an empty list
    if no tool calls are found.

    If true, and multiple tool calls are found, only the first one will be returned,
    and the other tool calls will be ignored.
    If no tool calls are found, None will be returned.
    """

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
        except KeyError:
            return []

        final_tools = []
        exceptions = []
        for tool_call in tool_calls:
            if "function" not in tool_call:
                continue
            try:
                if partial:
                    function_args = parse_partial_json(
                        tool_call["function"]["arguments"], strict=self.strict
                    )
                else:
                    function_args = json.loads(
                        tool_call["function"]["arguments"], strict=self.strict
                    )
            except JSONDecodeError as e:
                exceptions.append(
                    f"Function {tool_call['function']['name']} arguments:\n\n"
                    f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
                    f"Received JSONDecodeError {e}"
                )
                continue
            parsed = {
                "type": tool_call["function"]["name"],
                "args": function_args,
            }
            if self.return_id:
                parsed["id"] = tool_call["id"]
            final_tools.append(parsed)
        if exceptions:
            raise OutputParserException("\n\n".join(exceptions))
        if self.first_tool_only:
            return final_tools[0] if final_tools else None
        return final_tools


class JsonOutputKeyToolsParser(JsonOutputToolsParser):
    """Parse tools from OpenAI response."""

    key_name: str
    """The type of tools to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        parsed_result = super().parse_result(result, partial=partial)
        if self.first_tool_only:
            single_result = (
                parsed_result
                if parsed_result and parsed_result["type"] == self.key_name
                else None
            )
            if self.return_id:
                return single_result
            elif single_result:
                return single_result["args"]
            else:
                return None
        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
        if not self.return_id:
            parsed_result = [res["args"] for res in parsed_result]
        return parsed_result


class PydanticToolsParser(JsonOutputToolsParser):
    """Parse tools from OpenAI response."""

    tools: List[Type[BaseModel]]

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        parsed_result = super().parse_result(result, partial=partial)
        name_dict = {tool.__name__: tool for tool in self.tools}
        if self.first_tool_only:
            return (
                name_dict[parsed_result["type"]](**parsed_result["args"])
                if parsed_result
                else None
            )
        return [name_dict[res["type"]](**res["args"]) for res in parsed_result]
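A small sketch of how these parsers interpret an OpenAI tool-calling response; the AIMessage below is a hand-built stand-in for a real API response (the call id is made up):

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel

from langchain_openai.output_parsers import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)


class Foo(BaseModel):
    bar: str


# A hand-built message mimicking an OpenAI tool call.
message = AIMessage(
    content="",
    additional_kwargs={
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "Foo", "arguments": '{"bar": "baz"}'},
            }
        ]
    },
)
generations = [ChatGeneration(message=message)]

# Extract just the arguments of the 'Foo' tool call.
key_parser = JsonOutputKeyToolsParser(key_name="Foo", first_tool_only=True)
assert key_parser.parse_result(generations) == {"bar": "baz"}

# Or hydrate the arguments into the Pydantic class directly.
pydantic_parser = PydanticToolsParser(tools=[Foo], first_tool_only=True)
assert pydantic_parser.parse_result(generations) == Foo(bar="baz")
```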