community: Add MiniMaxChat bind_tools and structured output (#24310)

- **Description:**
  - Add `bind_tools` method to support tool calling
  - Add `with_structured_output` method to support structured output (see the usage sketch below)
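
For reference, a minimal usage sketch of the two new methods, assuming a valid `MINIMAX_API_KEY` in the environment; the `GetWeather` schema here is illustrative, not part of this commit:

```python
# Hedged sketch: GetWeather is an illustrative schema, not part of this commit.
from langchain_community.chat_models import MiniMaxChat
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


chat = MiniMaxChat()  # reads MINIMAX_API_KEY (and related settings) from the environment

# Tool calling: the model may populate tool_calls on the returned AIMessage.
ai_msg = chat.bind_tools([GetWeather]).invoke("How hot is it in LA right now?")
print(ai_msg.tool_calls)

# Structured output: the reply is parsed into an instance of the pydantic class.
weather = chat.with_structured_output(GetWeather).invoke("I want the weather for Los Angeles, CA")
print(weather.location)
```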
maang-h 2024-07-30 03:51:52 +08:00 committed by GitHub
parent 0a2ff40fcc
commit 4bb1a11e02
2 changed files with 334 additions and 4 deletions

View File

@@ -3,12 +3,25 @@
 import json
 import logging
 from contextlib import asynccontextmanager, contextmanager
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
+from operator import itemgetter
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -23,10 +36,19 @@ from langchain_core.messages import (
     ChatMessageChunk,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 logger = logging.getLogger(__name__)
@@ -77,9 +99,20 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     if isinstance(message, HumanMessage):
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
+        message_dict = {
+            "role": "assistant",
+            "content": message.content,
+            "tool_calls": message.additional_kwargs.get("tool_calls"),
+        }
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolMessage):
+        message_dict = {
+            "role": "tool",
+            "content": message.content,
+            "tool_call_id": message.tool_call_id,
+            "name": message.name or message.additional_kwargs.get("name"),
+        }
     else:
         raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
     return message_dict
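
For orientation, the new `ToolMessage` branch maps a LangChain tool result onto MiniMax's OpenAI-style `"tool"` role. A tiny sketch with hypothetical values:

```python
# Hypothetical values, illustrating the ToolMessage branch added above.
from langchain_core.messages import ToolMessage

msg = ToolMessage("36", tool_call_id="call_function_123", name="multiply")
# _convert_message_to_dict(msg) now returns:
# {"role": "tool", "content": "36",
#  "tool_call_id": "call_function_123", "name": "multiply"}
```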
@@ -230,6 +263,70 @@ class MiniMaxChat(BaseChatModel):
             id='run-c263b6f1-1736-4ece-a895-055c26b3436f-0'
         )

+    Tool calling:
+        .. code-block:: python
+
+            from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+            class GetWeather(BaseModel):
+                '''Get the current weather in a given location'''
+
+                location: str = Field(
+                    ..., description="The city and state, e.g. San Francisco, CA"
+                )
+
+
+            class GetPopulation(BaseModel):
+                '''Get the current population in a given location'''
+
+                location: str = Field(
+                    ..., description="The city and state, e.g. San Francisco, CA"
+                )
+
+
+            chat_with_tools = chat.bind_tools([GetWeather, GetPopulation])
+            ai_msg = chat_with_tools.invoke(
+                "Which city is hotter today and which is bigger: LA or NY?"
+            )
+            ai_msg.tool_calls
+
+        .. code-block:: python
+
+            [
+                {
+                    'name': 'GetWeather',
+                    'args': {'location': 'LA'},
+                    'id': 'call_function_2140449382',
+                    'type': 'tool_call'
+                }
+            ]
+
+    Structured output:
+        .. code-block:: python
+
+            from typing import Optional
+
+            from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+            class Joke(BaseModel):
+                '''Joke to tell user.'''
+
+                setup: str = Field(description="The setup of the joke")
+                punchline: str = Field(description="The punchline to the joke")
+                rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
+
+
+            structured_chat = chat.with_structured_output(Joke)
+            structured_chat.invoke("Tell me a joke about cats")
+
+        .. code-block:: python
+
+            Joke(
+                setup='Why do cats have nine lives?',
+                punchline='Because they are so cute and cuddly!',
+                rating=None
+            )
+
     Response metadata
         .. code-block:: python
@@ -242,7 +339,7 @@ class MiniMaxChat(BaseChatModel):
          'model_name': 'abab6.5-chat',
          'finish_reason': 'stop'}

-    """  # noqa: E501conj
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -342,12 +439,26 @@ class MiniMaxChat(BaseChatModel):
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         payload = self._default_params
         payload["messages"] = message_dicts
+        self._reformat_function_parameters(kwargs.get("tools", {}))
         payload.update(**kwargs)

         if is_stream:
             payload["stream"] = True

         return payload

+    @staticmethod
+    def _reformat_function_parameters(tools_arg: Dict[Any, Any]) -> None:
+        """Reformat the function parameters to strings."""
+        for tool_arg in tools_arg:
+            if tool_arg["type"] == "function" and not isinstance(
+                tool_arg["function"]["parameters"], str
+            ):
+                tool_arg["function"]["parameters"] = json.dumps(
+                    tool_arg["function"]["parameters"]
+                )
+
     def _generate(
         self,
         messages: List[BaseMessage],
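
The new helper suggests the MiniMax endpoint expects each function tool's `parameters` as a JSON string rather than a nested object, so the payload builder serializes them in place before the request. A standalone sketch of the same transformation (the tool entry is illustrative):

```python
# Standalone sketch of the in-place serialization performed by
# _reformat_function_parameters; the tool entry here is illustrative.
import json

tools = [
    {
        "type": "function",
        "function": {
            "name": "GetWeather",
            "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
        },
    }
]

for tool_arg in tools:
    if tool_arg["type"] == "function" and not isinstance(tool_arg["function"]["parameters"], str):
        tool_arg["function"]["parameters"] = json.dumps(tool_arg["function"]["parameters"])

assert isinstance(tools[0]["function"]["parameters"], str)
```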
@@ -521,3 +632,154 @@ class MiniMaxChat(BaseChatModel):
             if finish_reason is not None:
                 break

+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
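
Here `convert_to_openai_tool` normalizes dicts, pydantic classes, callables, and `BaseTool`s into the OpenAI tool schema, so a pydantic tool like the docstring's `GetWeather` is bound roughly as:

```python
# Approximate shape produced by convert_to_openai_tool for a pydantic class;
# field details depend on the schema.
{
    "type": "function",
    "function": {
        "name": "GetWeather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                }
            },
            "required": ["location"],
        },
    },
}
```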
+
+    def with_structured_output(
+        self,
+        schema: Union[Dict, Type[BaseModel]],
+        *,
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic
+                class then the model output will be an object of that class. If a
+                dict then the model output will be a dict. With a Pydantic class the
+                returned attributes will be validated, whereas with a dict they will
+                not be. If `method` is "function_calling" and `schema` is a dict,
+                then the dict must match the OpenAI function-calling spec.
+            include_raw: If False then only the parsed structured output is returned.
+                If an error occurs during model output parsing it will be raised. If
+                True then both the raw model response (a BaseMessage) and the parsed
+                model response will be returned. If an error occurs during output
+                parsing it will be caught and returned as well. The final output is
+                always a dict with keys "raw", "parsed", and "parsing_error".
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns as output:
+
+            If include_raw is True then a dict with keys:
+                raw: BaseMessage
+                parsed: Optional[_DictOrPydantic]
+                parsing_error: Optional[BaseException]
+
+            If include_raw is False then just _DictOrPydantic is returned,
+            where _DictOrPydantic depends on the schema:
+
+            If schema is a Pydantic class then _DictOrPydantic is the Pydantic
+            class.
+
+            If schema is a dict then _DictOrPydantic is a dict.
+
+        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
+            .. code-block:: python
+
+                from langchain_community.chat_models import MiniMaxChat
+                from langchain_core.pydantic_v1 import BaseModel
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                llm = MiniMaxChat()
+                structured_llm = llm.with_structured_output(AnswerWithJustification)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+
+                # -> AnswerWithJustification(
+                #     answer='A pound of bricks and a pound of feathers weigh the same.',
+                #     justification='The weight of the feathers is much less dense than the weight of the bricks, but since both weigh one pound, they weigh the same.'
+                # )
+
+        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
+            .. code-block:: python
+
+                from langchain_community.chat_models import MiniMaxChat
+                from langchain_core.pydantic_v1 import BaseModel
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                llm = MiniMaxChat()
+                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_function_8953642285', 'type': 'function', 'function': {'name': 'AnswerWithJustification', 'arguments': '{"answer": "A pound of bricks and a pound of feathers weigh the same.", "justification": "The weight of the feathers is much less dense than the weight of the bricks, but since both weigh one pound, they weigh the same."}'}}]}, response_metadata={'token_usage': {'total_tokens': 257}, 'model_name': 'abab6.5-chat', 'finish_reason': 'tool_calls'}, id='run-d897e037-2796-49f5-847e-f9f69dd390db-0', tool_calls=[{'name': 'AnswerWithJustification', 'args': {'answer': 'A pound of bricks and a pound of feathers weigh the same.', 'justification': 'The weight of the feathers is much less dense than the weight of the bricks, but since both weigh one pound, they weigh the same.'}, 'id': 'call_function_8953642285', 'type': 'tool_call'}]),
+                #     'parsed': AnswerWithJustification(answer='A pound of bricks and a pound of feathers weigh the same.', justification='The weight of the feathers is much less dense than the weight of the bricks, but since both weigh one pound, they weigh the same.'),
+                #     'parsing_error': None
+                # }
+
+        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
+            .. code-block:: python
+
+                from langchain_community.chat_models import MiniMaxChat
+                from langchain_core.pydantic_v1 import BaseModel
+                from langchain_core.utils.function_calling import convert_to_openai_tool
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                dict_schema = convert_to_openai_tool(AnswerWithJustification)
+                llm = MiniMaxChat()
+                structured_llm = llm.with_structured_output(dict_schema)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'answer': 'A pound of bricks and a pound of feathers both weigh the same, which is a pound.',
+                #     'justification': 'The difference is that bricks are much denser than feathers, so a pound of bricks will take up much less space than a pound of feathers.'
+                # }
+
+        """  # noqa: E501
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
+        llm = self.bind_tools([schema])
+        if is_pydantic_schema:
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[schema],  # type: ignore[list-item]
+                first_tool_only=True,  # type: ignore[list-item]
+            )
+        else:
+            key_name = convert_to_openai_tool(schema)["function"]["name"]
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
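
The `include_raw=True` branch uses a common LangChain composition: capture the raw message under `"raw"`, assign the parsed value alongside it, and fall back to `parsed=None` with the exception recorded under `"parsing_error"`. A self-contained sketch of the same pattern with toy stand-ins (no network; `llm` and `fragile_parse` are hypothetical, not MiniMax calls):

```python
# Toy reproduction of the raw/parsed/parsing_error composition above.
from operator import itemgetter

from langchain_core.runnables import RunnableLambda, RunnableMap, RunnablePassthrough

llm = RunnableLambda(lambda text: text)  # pretend model: echoes its input


def fragile_parse(raw: str) -> str:
    # Stand-in parser that fails on empty input.
    if not raw:
        raise ValueError("nothing to parse")
    return raw.upper()


parser_assign = RunnablePassthrough.assign(
    parsed=itemgetter("raw") | RunnableLambda(fragile_parse),
    parsing_error=lambda _: None,
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
chain = RunnableMap(raw=llm) | parser_assign.with_fallbacks(
    [parser_none], exception_key="parsing_error"
)

print(chain.invoke("hello"))
# {'raw': 'hello', 'parsed': 'HELLO', 'parsing_error': None}
print(chain.invoke(""))
# {'raw': '', 'parsing_error': ValueError('nothing to parse'), 'parsed': None}
```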

View File

@@ -1,6 +1,8 @@
 import os

-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import tool

 from langchain_community.chat_models import MiniMaxChat
@@ -19,3 +21,69 @@ def test_chat_minimax_with_stream() -> None:
     for chunk in chat.stream("你好呀"):
         assert isinstance(chunk, AIMessage)
         assert isinstance(chunk.content, str)
+
+
+@tool
+def add(a: int, b: int) -> int:
+    """Adds a and b."""
+    return a + b
+
+
+@tool
+def multiply(a: int, b: int) -> int:
+    """Multiplies a and b."""
+    return a * b
+
+
+def test_chat_minimax_with_tool() -> None:
+    """Test MiniMaxChat with bind_tools."""
+    chat = MiniMaxChat()  # type: ignore[call-arg]
+    tools = [add, multiply]
+    chat_with_tools = chat.bind_tools(tools)
+
+    query = "What is 3 * 12?"
+    messages = [HumanMessage(query)]
+    ai_msg = chat_with_tools.invoke(messages)
+    assert isinstance(ai_msg, AIMessage)
+    assert isinstance(ai_msg.tool_calls, list)
+    assert len(ai_msg.tool_calls) == 1
+    tool_call = ai_msg.tool_calls[0]
+    assert "args" in tool_call
+
+    messages.append(ai_msg)  # type: ignore[arg-type]
+
+    for tool_call in ai_msg.tool_calls:
+        selected_tool = {"add": add, "multiply": multiply}[tool_call["name"].lower()]
+        tool_output = selected_tool.invoke(tool_call["args"])  # type: ignore[attr-defined]
+        messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))  # type: ignore[arg-type]
+    response = chat_with_tools.invoke(messages)
+    assert isinstance(response, AIMessage)
+
+
+class AnswerWithJustification(BaseModel):
+    """An answer to the user question along with justification for the answer."""
+
+    answer: str
+    justification: str
+
+
+def test_chat_minimax_with_structured_output() -> None:
+    """Test MiniMaxChat with structured output."""
+    llm = MiniMaxChat()  # type: ignore
+    structured_llm = llm.with_structured_output(AnswerWithJustification)
+    response = structured_llm.invoke(
+        "What weighs more a pound of bricks or a pound of feathers"
+    )
+    assert isinstance(response, AnswerWithJustification)
+
+
+def test_chat_minimax_with_structured_output_include_raw() -> None:
+    """Test MiniMaxChat with structured output and include_raw=True."""
+    llm = MiniMaxChat()  # type: ignore
+    structured_llm = llm.with_structured_output(
+        AnswerWithJustification, include_raw=True
+    )
+    response = structured_llm.invoke(
+        "What weighs more a pound of bricks or a pound of feathers"
+    )
+    assert isinstance(response, dict)
+    assert isinstance(response.get("raw"), AIMessage)
+    assert isinstance(response.get("parsed"), AnswerWithJustification)