@ -12,6 +12,7 @@ from typing import (
)
)
from cohere . types import NonStreamedChatResponse , ToolCall
from cohere . types import NonStreamedChatResponse , ToolCall
from langchain_core . _api import beta
from langchain_core . callbacks import (
from langchain_core . callbacks import (
AsyncCallbackManagerForLLMRun ,
AsyncCallbackManagerForLLMRun ,
CallbackManagerForLLMRun ,
CallbackManagerForLLMRun ,
@ -30,12 +31,20 @@ from langchain_core.messages import (
HumanMessage ,
HumanMessage ,
SystemMessage ,
SystemMessage ,
)
)
from langchain_core . output_parsers . base import OutputParserLike
from langchain_core . output_parsers . openai_tools import (
JsonOutputKeyToolsParser ,
PydanticToolsParser ,
)
from langchain_core . outputs import ChatGeneration , ChatGenerationChunk , ChatResult
from langchain_core . outputs import ChatGeneration , ChatGenerationChunk , ChatResult
from langchain_core . pydantic_v1 import BaseModel
from langchain_core . pydantic_v1 import BaseModel
from langchain_core . runnables import Runnable
from langchain_core . runnables import Runnable
from langchain_core . tools import BaseTool
from langchain_core . tools import BaseTool
from langchain_cohere . cohere_agent import _format_to_cohere_tools
from langchain_cohere . cohere_agent import (
_convert_to_cohere_tool ,
_format_to_cohere_tools ,
)
from langchain_cohere . llms import BaseCohere
from langchain_cohere . llms import BaseCohere
@ -165,6 +174,39 @@ class ChatCohere(BaseChatModel, BaseCohere):
formatted_tools = _format_to_cohere_tools ( tools )
formatted_tools = _format_to_cohere_tools ( tools )
return super ( ) . bind ( tools = formatted_tools , * * kwargs )
return super ( ) . bind ( tools = formatted_tools , * * kwargs )
@beta ( )
def with_structured_output (
self ,
schema : Union [ Dict , Type [ BaseModel ] ] ,
* * kwargs : Any ,
) - > Runnable [ LanguageModelInput , Union [ Dict , BaseModel ] ] :
""" Model wrapper that returns outputs formatted to match the given schema.
Args :
schema : The output schema as a dict or a Pydantic class . If a Pydantic class
then the model output will be an object of that class . If a dict then
the model output will be a dict .
Returns :
A Runnable that takes any ChatModel input and returns either a dict or
Pydantic class as output .
"""
if kwargs :
raise ValueError ( f " Received unsupported arguments { kwargs } " )
is_pydantic_schema = isinstance ( schema , type ) and issubclass ( schema , BaseModel )
llm = self . bind_tools ( [ schema ] )
if is_pydantic_schema :
output_parser : OutputParserLike = PydanticToolsParser (
tools = [ schema ] , first_tool_only = True
)
else :
key_name = _convert_to_cohere_tool ( schema ) [ " name " ]
output_parser = JsonOutputKeyToolsParser (
key_name = key_name , first_tool_only = True
)
return llm | output_parser
@property
@property
def _identifying_params ( self ) - > Dict [ str , Any ] :
def _identifying_params ( self ) - > Dict [ str , Any ] :
""" Get the identifying parameters. """
""" Get the identifying parameters. """