mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
cohere: add with_structured_output to ChatCohere (#19730)
**Description:** Adds support for `with_structured_output` to Cohere, which supports single function calling. --------- Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
This commit is contained in:
parent
0571f886d1
commit
ea57050122
@ -12,6 +12,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from cohere.types import NonStreamedChatResponse, ToolCall
|
from cohere.types import NonStreamedChatResponse, ToolCall
|
||||||
|
from langchain_core._api import beta
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -30,12 +31,20 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
JsonOutputKeyToolsParser,
|
||||||
|
PydanticToolsParser,
|
||||||
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
from langchain_cohere.cohere_agent import (
|
||||||
|
_convert_to_cohere_tool,
|
||||||
|
_format_to_cohere_tools,
|
||||||
|
)
|
||||||
from langchain_cohere.llms import BaseCohere
|
from langchain_cohere.llms import BaseCohere
|
||||||
|
|
||||||
|
|
||||||
@ -165,6 +174,39 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
formatted_tools = _format_to_cohere_tools(tools)
|
formatted_tools = _format_to_cohere_tools(tools)
|
||||||
return super().bind(tools=formatted_tools, **kwargs)
|
return super().bind(tools=formatted_tools, **kwargs)
|
||||||
|
|
||||||
|
@beta()
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Union[Dict, Type[BaseModel]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||||
|
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||||
|
then the model output will be an object of that class. If a dict then
|
||||||
|
the model output will be a dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that takes any ChatModel input and returns either a dict or
|
||||||
|
Pydantic class as output.
|
||||||
|
"""
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
|
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||||
|
llm = self.bind_tools([schema])
|
||||||
|
if is_pydantic_schema:
|
||||||
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
|
tools=[schema], first_tool_only=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
key_name = _convert_to_cohere_tool(schema)["name"]
|
||||||
|
output_parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name=key_name, first_tool_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return llm | output_parser
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
|
Loading…
Reference in New Issue
Block a user