From ea570501226514103b6d47e3db20e5eb158ba746 Mon Sep 17 00:00:00 2001
From: harry-cohere <127103098+harry-cohere@users.noreply.github.com>
Date: Thu, 28 Mar 2024 19:09:25 +0000
Subject: [PATCH] cohere: add with_structured_output to ChatCohere (#19730)

**Description:** Adds support for `with_structured_output` to Cohere, which
supports single function calling.

---------

Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
---
 .../cohere/langchain_cohere/chat_models.py | 44 ++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py
index 004973fb93..ce2ce6a568 100644
--- a/libs/partners/cohere/langchain_cohere/chat_models.py
+++ b/libs/partners/cohere/langchain_cohere/chat_models.py
@@ -12,6 +12,7 @@ from typing import (
 )
 
 from cohere.types import NonStreamedChatResponse, ToolCall
+from langchain_core._api import beta
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -30,12 +31,20 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 
-from langchain_cohere.cohere_agent import _format_to_cohere_tools
+from langchain_cohere.cohere_agent import (
+    _convert_to_cohere_tool,
+    _format_to_cohere_tools,
+)
 from langchain_cohere.llms import BaseCohere
 
 
@@ -165,6 +174,39 @@ class ChatCohere(BaseChatModel, BaseCohere):
         formatted_tools = _format_to_cohere_tools(tools)
         return super().bind(tools=formatted_tools, **kwargs)
 
+    @beta()
+    def with_structured_output(
+        self,
+        schema: Union[Dict, Type[BaseModel]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
+                then the model output will be an object of that class. If a dict then
+                the model output will be a dict.
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns either a dict or
+            Pydantic class as output.
+        """
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
+        llm = self.bind_tools([schema])
+        if is_pydantic_schema:
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[schema], first_tool_only=True
+            )
+        else:
+            key_name = _convert_to_cohere_tool(schema)["name"]
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+
+        return llm | output_parser
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
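
For context, a minimal usage sketch of the method this patch adds. The `Person` schema, the prompt string, and the `model="command-r"` choice are illustrative assumptions, not part of the patch, and the snippet assumes a valid Cohere API key is configured.

```python
# Hypothetical usage sketch for ChatCohere.with_structured_output (added above).
# The Person schema, prompt, and model name are illustrative assumptions.
from langchain_cohere import ChatCohere
from langchain_core.pydantic_v1 import BaseModel


class Person(BaseModel):
    """Information about a person."""

    name: str
    age: int


llm = ChatCohere(model="command-r")  # assumes COHERE_API_KEY is set in the environment
structured_llm = llm.with_structured_output(Person)

result = structured_llm.invoke("Ada Lovelace was 36 years old.")
# Because a Pydantic class was passed, the output is parsed with
# PydanticToolsParser(first_tool_only=True), so `result` should be a Person
# instance, e.g. Person(name="Ada Lovelace", age=36).
```

Passing a dict schema instead would route through `JsonOutputKeyToolsParser`, so the result would be a plain dict rather than a Pydantic object.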