community: Implement `bind_tools` for ChatTongyi (#20725)

## Description

Implement `bind_tools` in ChatTongyi. Usage example:

```py
from langchain_core.tools import tool
from langchain_community.chat_models.tongyi import ChatTongyi

@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int

llm = ChatTongyi(model="qwen-turbo")

llm_with_tools = llm.bind_tools([multiply])

msg = llm_with_tools.invoke("What's 5 times forty two")

print(msg)
```
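
The returned `AIMessage` exposes the parsed call via `msg.tool_calls`, so the requested tool can be run directly from it. A minimal sketch, continuing the example above:

```py
# Run the tool call the model requested (sketch, continuing the example above)
tool_call = msg.tool_calls[0]  # e.g. {'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, ...}
result = multiply.invoke(tool_call["args"])
print(tool_call["name"], "->", result)  # multiply -> 210
```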

Streaming is also supported.
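
For reference, a minimal streaming sketch (reusing `llm_with_tools` from above); `AIMessageChunk`s can be added together to accumulate partial tool calls:

```py
# Accumulate streamed chunks into a single message (sketch, reusing llm_with_tools)
gathered = None
for chunk in llm_with_tools.stream("What's 5 times forty two"):
    gathered = chunk if gathered is None else gathered + chunk

print(gathered.tool_calls)  # the fully accumulated tool call, if one was emitted
```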

## Dependencies

No new dependencies are required for this change.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>

@@ -26,14 +26,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Install the package\n",
"%pip install --upgrade --quiet dashscope"
@@ -48,15 +56,7 @@
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"outputs": [],
"source": [
"# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n",
"from getpass import getpass\n",
@@ -94,8 +94,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"chat resp: content='Hello! How' additional_kwargs={} example=False\n",
"chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n"
"chat resp: content='Hello' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
"chat resp: content='!' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
"chat resp: content=' How' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
"chat resp: content=' can I assist you today' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
"chat resp: content='?' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
"chat resp: content='' response_metadata={'finish_reason': 'stop', 'request_id': '921db2c5-4d53-9a89-8e87-e4ad6a671237', 'token_usage': {'input_tokens': 20, 'output_tokens': 9, 'total_tokens': 29}} id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n"
]
}
],
@@ -116,10 +120,18 @@
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cheese/PARA/Projects/langchain-contribution/langchain/libs/core/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The method `BaseChatModel.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
" warn_deprecated(\n"
]
},
{
"data": {
"text/plain": [
"AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
"AIMessage(content=\"J'adore programmer.\", response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'ae725086-0ffa-9728-8c72-b204c7bc7eeb', 'token_usage': {'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}}, id='run-060cc103-ef5f-4c8a-af40-792ac7f40c26-0')"
]
},
"execution_count": 5,
@@ -149,18 +161,65 @@
"ChatTongyi supports tool calling API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use with `bind_tools`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"content='' additional_kwargs={'tool_calls': [{'function': {'name': 'multiply', 'arguments': '{\"first_int\": 5, \"second_int\": 42}'}, 'id': '', 'type': 'function'}]} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '4acf0e36-44af-987a-a0c0-8b5c5eaa1a8b', 'token_usage': {'input_tokens': 200, 'output_tokens': 25, 'total_tokens': 225}} id='run-0ecd0f09-1d20-4e55-a4f3-f14d1f710ae7-0' tool_calls=[{'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, 'id': ''}]\n"
]
}
],
"source": [
"from langchain_community.chat_models.tongyi import ChatTongyi\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def multiply(first_int: int, second_int: int) -> int:\n",
" \"\"\"Multiply two integers together.\"\"\"\n",
" return first_int * second_int\n",
"\n",
"\n",
"llm = ChatTongyi(model=\"qwen-turbo\")\n",
"\n",
"llm_with_tools = llm.bind_tools([multiply])\n",
"\n",
"msg = llm_with_tools.invoke(\"What's 5 times forty two\")\n",
"\n",
"print(msg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Construct args manually"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': 'dae79197-8780-9b7e-8c15-6a83e2a53534', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-9e06f837-582b-473b-bb1f-5e99a68ecc10-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
"AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '87ef33d2-5c6b-9457-91e2-39faad7120eb', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-7939ba7f-e3f7-46f8-980b-30499b52723c-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
]
},
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -224,7 +283,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.2"
}
},
"nbformat": 4,

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import functools
import json
import logging
from typing import (
Any,
@@ -12,6 +13,8 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
@@ -20,6 +23,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -32,6 +36,8 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
@@ -42,8 +48,11 @@ from langchain_core.outputs import (
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
@@ -68,6 +77,7 @@ def convert_dict_to_message(
"""Convert a dict to a message."""
role = _dict["role"]
content = _dict["content"]
if role == "user":
return (
HumanMessageChunk(content=content)
@@ -79,17 +89,39 @@
invalid_tool_calls = []
if "tool_calls" in _dict:
additional_kwargs = {"tool_calls": _dict["tool_calls"]}
for raw_tool_call in _dict["tool_calls"]:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
for index, value in enumerate(_dict["tool_calls"]):
if is_chunk:
try:
tool_calls.append(
{
"name": value["function"].get("name"),
"args": value["function"].get("arguments"),
"id": value.get("id"),
# Tongyi does not respond with index,
# use index in the list instead
"index": index,
}
)
except KeyError:
pass
else:
try:
parsed_tool = parse_tool_call(value, return_id=True)
if parsed_tool:
tool_calls.append(parsed_tool)
except Exception as e:
invalid_tool_calls.append(make_invalid_tool_call(value, str(e)))
else:
additional_kwargs = {}
return (
AIMessageChunk(content=content)
AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_calls,
id=_dict.get("id"),
)
if is_chunk
else AIMessage(
content=content,
@@ -104,6 +136,23 @@ def convert_dict_to_message(
if is_chunk
else SystemMessage(content=content)
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return (
ToolMessageChunk(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
if is_chunk
else ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
)
else:
return (
ChatMessageChunk(role=role, content=content)
@@ -113,17 +162,23 @@
def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
"""Convert a message chunk to a message."""
if isinstance(message_chunk, HumanMessageChunk):
return HumanMessage(content=message_chunk.content)
elif isinstance(message_chunk, AIMessageChunk):
return AIMessage(content=message_chunk.content)
elif isinstance(message_chunk, SystemMessageChunk):
return SystemMessage(content=message_chunk.content)
elif isinstance(message_chunk, ChatMessageChunk):
return ChatMessage(role=message_chunk.role, content=message_chunk.content)
else:
raise TypeError(f"Got unknown type {message_chunk}")
"""Convert a message chunk to a message.
Args:
message_chunk: Message chunk to convert.
Returns:
Message.
"""
if not isinstance(message_chunk, BaseMessageChunk):
return message_chunk
# chunk classes always have the equivalent non-chunk class as their first parent
ignore_keys = ["type"]
if isinstance(message_chunk, AIMessageChunk):
ignore_keys.append("tool_call_chunks")
return message_chunk.__class__.__mro__[1](
**{k: v for k, v in message_chunk.__dict__.items() if k not in ignore_keys}
)
def convert_message_to_dict(message: BaseMessage) -> dict:
@@ -136,8 +191,17 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"tool_call_id": message.tool_call_id,
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
@@ -256,11 +320,57 @@ class ChatTongyi(BaseChatModel):
@retry_decorator
def _stream_completion_with_retry(**_kwargs: Any) -> Any:
responses = self.client.call(**_kwargs)
prev_resp = None
for resp in responses:
yield check_response(resp)
# If we are streaming without `incremental_output = True`,
# we need to calculate the delta response manually
if _kwargs.get("stream") and not _kwargs.get(
"incremental_output", False
):
if prev_resp is None:
delta_resp = resp
else:
delta_resp = self.subtract_client_response(resp, prev_resp)
prev_resp = resp
yield check_response(delta_resp)
else:
yield check_response(resp)
return _stream_completion_with_retry(**kwargs)
def subtract_client_response(self, resp: Any, prev_resp: Any) -> Any:
"""Subtract prev response from curr response.
Useful when streaming without `incremental_output = True`
"""
resp_copy = json.loads(json.dumps(resp))
choice = resp_copy["output"]["choices"][0]
message = choice["message"]
prev_resp_copy = json.loads(json.dumps(prev_resp))
prev_choice = prev_resp_copy["output"]["choices"][0]
prev_message = prev_choice["message"]
message["content"] = message["content"].replace(prev_message["content"], "")
if message.get("tool_calls"):
for index, tool_call in enumerate(message["tool_calls"]):
function = tool_call["function"]
if prev_message.get("tool_calls"):
prev_function = prev_message["tool_calls"][index]["function"]
function["name"] = function["name"].replace(
prev_function["name"], ""
)
function["arguments"] = function["arguments"].replace(
prev_function["arguments"], ""
)
return resp_copy
async def astream_completion_with_retry(self, **kwargs: Any) -> Any:
"""Because the dashscope SDK doesn't provide an async API,
we wrap `stream_generate_with_retry` with an async generator."""
@@ -301,16 +411,16 @@
) -> ChatResult:
generations = []
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
generation_chunk: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
if generation_chunk is None:
generation_chunk = chunk
else:
generation += chunk
assert generation is not None
generations.append(self._chunk_to_generation(generation))
generation_chunk += chunk
assert generation_chunk is not None
generations.append(self._chunk_to_generation(generation_chunk))
else:
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, **kwargs
@@ -373,9 +483,19 @@
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, stream=True, **kwargs
)
for stream_resp, is_last_chunk in generate_with_last_element_mark(
self.stream_completion_with_retry(**params)
):
choice = stream_resp["output"]["choices"][0]
message = choice["message"]
if (
choice["finish_reason"] == "null"
and message["content"] == ""
and "tool_calls" not in message
):
continue
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(
stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
@@ -413,14 +533,13 @@
params = {**self._default_params, **kwargs}
if stop is not None:
params["stop"] = stop
if params.get("stream"):
# According to the Tongyi official docs,
# `incremental_output` with `tools` is not supported yet
if params.get("stream") and not params.get("tools"):
params["incremental_output"] = True
message_dicts = [convert_message_to_dict(m) for m in messages]
# According to the docs, the last message should be a `user` message
if message_dicts[-1]["role"] != "user":
raise ValueError("Last message should be user message.")
# And the `system` message should be the first message if present
system_message_indices = [
i for i, m in enumerate(message_dicts) if m["role"] == "system"
@@ -470,3 +589,22 @@
message=convert_message_chunk_to_message(chunk.message),
generation_info=chunk.generation_info,
)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)

@@ -55,17 +55,17 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]:
def check_response(resp: Any) -> Any:
"""Check the response from the completion call."""
if resp.status_code == 200:
if resp["status_code"] == 200:
return resp
elif resp.status_code in [400, 401]:
elif resp["status_code"] in [400, 401]:
raise ValueError(
f"status_code: {resp.status_code} \n "
f"code: {resp.code} \n message: {resp.message}"
f"status_code: {resp['status_code']} \n "
f"code: {resp['code']} \n message: {resp['message']}"
)
else:
raise HTTPError(
f"HTTP error occurred: status_code: {resp.status_code} \n "
f"code: {resp.code} \n message: {resp.message}",
f"HTTP error occurred: status_code: {resp['status_code']} \n "
f"code: {resp['code']} \n message: {resp['message']}",
response=resp,
)

@@ -1,11 +1,14 @@
"""Test Alibaba Tongyi Chat Model."""
from typing import Any, cast
from typing import Any, List, cast
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.pydantic_v1 import BaseModel, SecretStr
from pytest import CaptureFixture
from langchain_community.chat_models.tongyi import ChatTongyi
@@ -138,3 +141,76 @@ def test_multiple_messages() -> None:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
name: str
hair_color: str
def test_tool_use() -> None:
llm = ChatTongyi(model="qwen-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
ai_msg = llm_with_tool.invoke(msgs)
assert isinstance(ai_msg, AIMessage)
assert isinstance(ai_msg.tool_calls, list)
assert len(ai_msg.tool_calls) == 1
tool_call = ai_msg.tool_calls[0]
assert "args" in tool_call
tool_msg = ToolMessage(
"sally_green_hair",
tool_call_id=ai_msg.tool_calls[0]["id"],
name=ai_msg.tool_calls[0]["name"],
)
msgs.extend([ai_msg, tool_msg])
llm_with_tool.invoke(msgs)
# Test streaming
ai_messages = llm_with_tool.stream(msgs)
first = True
for message in ai_messages:
if first:
gathered = message
first = False
else:
gathered = gathered + message # type: ignore
assert isinstance(gathered, AIMessageChunk)
streaming_tool_msg = ToolMessage(
"sally_green_hair",
name=tool_call["name"],
tool_call_id=tool_call["id"] if tool_call["id"] else " ",
)
msgs.extend([gathered, streaming_tool_msg])
llm_with_tool.invoke(msgs)
def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatTongyi(model="qwen-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content=" ",
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="foo",
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
assert output.content
# Should not have called the tool again.
assert not output.tool_calls and not output.invalid_tool_calls
