cohere: add with_structured_output to ChatCohere (#19730)

**Description:** Adds support for `with_structured_output` to Cohere,
which supports single function calling.

---------

Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
This commit is contained in:
harry-cohere 2024-03-28 19:09:25 +00:00 committed by GitHub
parent 0571f886d1
commit ea57050122
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ from typing import (
) )
from cohere.types import NonStreamedChatResponse, ToolCall from cohere.types import NonStreamedChatResponse, ToolCall
from langchain_core._api import beta
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
@ -30,12 +31,20 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_cohere.cohere_agent import _format_to_cohere_tools from langchain_cohere.cohere_agent import (
_convert_to_cohere_tool,
_format_to_cohere_tools,
)
from langchain_cohere.llms import BaseCohere from langchain_cohere.llms import BaseCohere
@ -165,6 +174,39 @@ class ChatCohere(BaseChatModel, BaseCohere):
formatted_tools = _format_to_cohere_tools(tools) formatted_tools = _format_to_cohere_tools(tools)
return super().bind(tools=formatted_tools, **kwargs) return super().bind(tools=formatted_tools, **kwargs)
@beta()
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict.
Returns:
A Runnable that takes any ChatModel input and returns either a dict or
Pydantic class as output.
"""
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
llm = self.bind_tools([schema])
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = _convert_to_cohere_tool(schema)["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
return llm | output_parser
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""