core[minor]: BaseChatModel with_structured_output implementation (#22859)

pull/20491/head^2
Brace Sproul authored 2 weeks ago; committed by GitHub
parent 360a70c8a8
commit abe7566d7d

@@ -5,6 +5,7 @@ import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import (
    TYPE_CHECKING,
    Any,
@@ -54,10 +55,13 @@ from langchain_core.outputs import (
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import convert_to_openai_tool

if TYPE_CHECKING:
    from langchain_core.output_parsers.base import OutputParserLike
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.runnables import Runnable, RunnableConfig
    from langchain_core.tools import BaseTool
@@ -1024,6 +1028,140 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        raise NotImplementedError()

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
                then the model output will be an object of that class. If a dict then
                the model output will be a dict. With a Pydantic class the returned
                attributes will be validated, whereas with a dict they will not be. If
                `method` is "function_calling" and `schema` is a dict, then the dict
                must match the OpenAI function-calling spec.
            include_raw: If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes any ChatModel input and returns as output:

                If include_raw is True then a dict with keys:
                    raw: BaseMessage
                    parsed: Optional[_DictOrPydantic]
                    parsing_error: Optional[BaseException]

                If include_raw is False then just _DictOrPydantic is returned,
                where _DictOrPydantic depends on the schema:

                If schema is a Pydantic class then _DictOrPydantic is the Pydantic
                    class.

                If schema is a dict then _DictOrPydantic is a dict.

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

                # -> AnswerWithJustification(
                #     answer='They weigh the same',
                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                # )

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
            .. code-block:: python

                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_core.pydantic_v1 import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }
        """  # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        from langchain_core.output_parsers.openai_tools import (
            JsonOutputKeyToolsParser,
            PydanticToolsParser,
        )

        if self.bind_tools is BaseChatModel.bind_tools:
            raise NotImplementedError(
                "with_structured_output is not implemented for this model."
            )
        llm = self.bind_tools([schema], tool_choice="any")
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[schema], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

class SimpleChatModel(BaseChatModel):
    """Simplified implementation for a chat model to inherit from.

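The most intricate piece of the new method is the `include_raw` branch: `RunnablePassthrough.assign` derives `parsed` from the raw message, and `with_fallbacks(..., exception_key="parsing_error")` turns a parser exception into a value under the `"parsing_error"` key rather than a raised error. Below is a minimal, runnable sketch of that same pattern, with a stand-in model and a deliberately failing parser in place of the real chat model and tool-call parser (all names are illustrative, not from this PR):

.. code-block:: python

    from operator import itemgetter

    from langchain_core.runnables import (
        RunnableLambda,
        RunnableMap,
        RunnablePassthrough,
    )

    # Stand-in for the tool-bound model: echoes the prompt back as a string.
    llm = RunnableLambda(lambda prompt: str(prompt))

    # Stand-in for the tool-call parser: fails on purpose for some inputs.
    def parse(text: str) -> dict:
        if "boom" in text:
            raise ValueError("unparseable output")
        return {"ok": text}

    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | RunnableLambda(parse),
        parsing_error=lambda _: None,
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    # On parser failure, fall back and surface the exception under the
    # "parsing_error" key instead of raising it.
    chain = RunnableMap(raw=llm) | parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )

    chain.invoke("hello")
    # -> {'raw': 'hello', 'parsed': {'ok': 'hello'}, 'parsing_error': None}
    chain.invoke("boom")
    # -> {'raw': 'boom', 'parsing_error': ValueError('unparseable output'), 'parsed': None}

The fallback receives the original `{"raw": ...}` input with the exception attached under `exception_key`, which is why `parser_none` only needs to fill in `parsed=None`.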
@@ -334,7 +334,7 @@ class FakeStructuredOutputModel(BaseChatModel):
    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
-        return self | (lambda x: {"foo": self.foo})
+        return RunnableLambda(lambda x: {"foo": self.foo})

    @property
    def _llm_type(self) -> str:
@@ -388,6 +388,3 @@ def test_fallbacks_getattr_runnable_output() -> None:
        for fallback in llm_with_fallbacks_with_tools.fallbacks
    )
    assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
-    with pytest.raises(NotImplementedError):
-        llm_with_fallbacks.with_structured_output({})

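The one-line fix to `FakeStructuredOutputModel` above is subtle: `self | (lambda x: ...)` builds a `RunnableSequence` that invokes the fake chat model before applying the lambda, whereas the test only needs a runnable that returns the structured value directly. A small sketch of the difference, using a stand-in runnable instead of a chat model (names are illustrative):

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    model_like = RunnableLambda(lambda x: f"<model output for {x!r}>")

    piped = model_like | (lambda x: {"foo": "bar"})   # invokes model_like first
    direct = RunnableLambda(lambda x: {"foo": "bar"})  # never touches the model

    assert piped.invoke("hi") == {"foo": "bar"}
    assert direct.invoke("hi") == {"foo": "bar"}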
@@ -6,7 +6,6 @@ from typing import (
    Callable,
    Dict,
    List,
-    Literal,
    Optional,
    Sequence,
    Type,
@@ -14,7 +13,6 @@ from typing import (
    TypeVar,
    Union,
    cast,
-    overload,
)
from langchain_community.chat_models.ollama import ChatOllama
@@ -72,7 +70,6 @@ DEFAULT_RESPONSE_FUNCTION = {
}
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
_DictOrPydantic = Union[Dict, _BM]
@@ -151,33 +148,13 @@ class OllamaFunctions(ChatOllama):
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        return self.bind(functions=tools, **kwargs)

-    @overload
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        include_raw: Literal[True] = True,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _AllReturnType]:
-        ...
-
-    @overload
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        include_raw: Literal[False] = False,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
-        ...
-
    def with_structured_output(
        self,
-        schema: Optional[_DictOrPydanticClass] = None,
+        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:

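With the `@overload` stubs gone, `OllamaFunctions.with_structured_output` now matches the signature it inherits from `BaseChatModel`. A hedged usage sketch, assuming a locally running Ollama server and the `langchain_experimental` package (the model name is illustrative):

.. code-block:: python

    from langchain_core.pydantic_v1 import BaseModel
    from langchain_experimental.llms.ollama_functions import OllamaFunctions

    class Joke(BaseModel):
        """A joke with a setup and a punchline."""

        setup: str
        punchline: str

    # Constructing the model does not contact the server; invoking it does.
    model = OllamaFunctions(model="llama3", format="json")
    structured_llm = model.with_structured_output(Joke)
    # structured_llm.invoke("Tell me a joke about cars")
    # -> Joke(setup=..., punchline=...)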
@@ -135,6 +135,7 @@ class TestOllamaFunctions(unittest.TestCase):
        structured_llm = model.with_structured_output(Joke, include_raw=True)

        res = structured_llm.invoke("Tell me a joke about cars")
        assert isinstance(res, dict)
+        assert "raw" in res
        assert "parsed" in res
        assert isinstance(res["raw"], AIMessage)

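The strengthened test pins down the `include_raw=True` contract introduced by the base implementation. The same shape as a self-contained check, with illustrative values standing in for real model output:

.. code-block:: python

    from langchain_core.messages import AIMessage

    # Illustrative stand-in for structured_llm.invoke("Tell me a joke about cars")
    res = {
        "raw": AIMessage(content='{"setup": "...", "punchline": "..."}'),
        "parsed": {"setup": "...", "punchline": "..."},
        "parsing_error": None,
    }
    assert isinstance(res, dict)
    assert "raw" in res
    assert "parsed" in res
    assert isinstance(res["raw"], AIMessage)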