community[minor]: add tool calling for DeepInfraChat (#22745)

DeepInfra now supports tool calling for supported models.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Oguz Vuruskaner 2024-06-17 12:21:49 -07:00 committed by GitHub
parent 158701ab3c
commit dd25d08c06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 211 additions and 17 deletions

View File

@ -98,6 +98,78 @@
")\n", ")\n",
"chat.invoke(messages)" "chat.invoke(messages)"
] ]
},
{
"cell_type": "markdown",
"id": "466c3cb41ace1410",
"metadata": {},
"source": [
"# Tool Calling\n",
"\n",
"DeepInfra currently supports only invoke and async invoke tool calling.\n",
"\n",
"For a complete list of models that support tool calling, please refer to our [tool calling documentation](https://deepinfra.com/docs/advanced/function_calling)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ddc4f4299763651c",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"from dotenv import find_dotenv, load_dotenv\n",
"from langchain_community.chat_models import ChatDeepInfra\n",
"from langchain_core.messages import HumanMessage\n",
"from langchain_core.pydantic_v1 import BaseModel\n",
"from langchain_core.tools import tool\n",
"\n",
"model_name = \"meta-llama/Meta-Llama-3-70B-Instruct\"\n",
"\n",
"_ = load_dotenv(find_dotenv())\n",
"\n",
"\n",
"# Langchain tool\n",
"@tool\n",
"def foo(something):\n",
" \"\"\"\n",
" Called when foo\n",
" \"\"\"\n",
" pass\n",
"\n",
"\n",
"# Pydantic class\n",
"class Bar(BaseModel):\n",
" \"\"\"\n",
" Called when Bar\n",
" \"\"\"\n",
"\n",
" pass\n",
"\n",
"\n",
"llm = ChatDeepInfra(model=model_name)\n",
"tools = [foo, Bar]\n",
"llm_with_tools = llm.bind_tools(tools)\n",
"messages = [\n",
" HumanMessage(\"Foo and bar, please.\"),\n",
"]\n",
"\n",
"response = llm_with_tools.invoke(messages)\n",
"print(response.tool_calls)\n",
"# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_Mi4N4wAtW89OlbizFE1aDxDj'}, {'name': 'Bar', 'args': {}, 'id': 'call_daiE0mW454j2O1KVbmET4s2r'}]\n",
"\n",
"\n",
"async def call_ainvoke():\n",
" result = await llm_with_tools.ainvoke(messages)\n",
" print(result.tool_calls)\n",
"\n",
"\n",
"# Async call\n",
"asyncio.run(call_ainvoke())\n",
"# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_ZH7FetmgSot4LHcMU6CEb8tI'}, {'name': 'Bar', 'args': {}, 'id': 'call_2MQhDifAJVoijZEvH8PeFSVB'}]"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -13,6 +13,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
Union, Union,
@ -24,6 +25,7 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import ( from langchain_core.language_models.chat_models import (
BaseChatModel, BaseChatModel,
agenerate_from_stream, agenerate_from_stream,
@ -44,15 +46,18 @@ from langchain_core.messages import (
SystemMessage, SystemMessage,
SystemMessageChunk, SystemMessageChunk,
) )
from langchain_core.messages.tool import ToolCall
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
ChatGenerationChunk, ChatGenerationChunk,
ChatResult, ChatResult,
) )
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
# from langchain.llms.base import create_base_retry_decorator
from langchain_community.utilities.requests import Requests from langchain_community.utilities.requests import Requests
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -78,19 +83,51 @@ def _create_retry_decorator(
) )
def _parse_tool_calling(tool_call: dict) -> ToolCall:
"""
Convert a tool calling response from server to a ToolCall object.
Args:
tool_call:
Returns:
"""
name = tool_call.get("name", "")
args = json.loads(tool_call["function"]["arguments"])
id = tool_call.get("id")
return ToolCall(name=name, args=args, id=id)
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
"""
Convert a ToolCall object to a tool calling request for server.
Args:
tool_call:
Returns:
"""
return {
"type": "function",
"function": {
"arguments": json.dumps(tool_call["args"]),
"name": tool_call["name"],
},
"id": tool_call.get("id"),
}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"] role = _dict["role"]
if role == "user": if role == "user":
return HumanMessage(content=_dict["content"]) return HumanMessage(content=_dict["content"])
elif role == "assistant": elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or "" content = _dict.get("content", "") or ""
if _dict.get("function_call"): tool_calls_content = _dict.get("tool_calls", []) or []
additional_kwargs = {"function_call": dict(_dict["function_call"])} tool_calls = [
else: _parse_tool_calling(tool_call) for tool_call in tool_calls_content
additional_kwargs = {} ]
return AIMessage(content=content, additional_kwargs=additional_kwargs) return AIMessage(content=content, tool_calls=tool_calls)
elif role == "system": elif role == "system":
return SystemMessage(content=_dict["content"]) return SystemMessage(content=_dict["content"])
elif role == "function": elif role == "function":
@ -104,15 +141,14 @@ def _convert_delta_to_message_chunk(
) -> BaseMessageChunk: ) -> BaseMessageChunk:
role = _dict.get("role") role = _dict.get("role")
content = _dict.get("content") or "" content = _dict.get("content") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) tool_calls = [
_parse_tool_calling(tool_call) for tool_call in _dict.get("tool_calls", [])
]
return AIMessageChunk(content=content, tool_calls=tool_calls)
elif role == "system" or default_class == SystemMessageChunk: elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content) return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk: elif role == "function" or default_class == FunctionMessageChunk:
@ -129,9 +165,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content} message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content} tool_calls = [
if "function_call" in message.additional_kwargs: _convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
message_dict["function_call"] = message.additional_kwargs["function_call"] ]
message_dict = {
"role": "assistant",
"content": message.content,
"tool_calls": tool_calls, # type: ignore[dict-item]
}
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage): elif isinstance(message, FunctionMessage):
@ -417,6 +458,27 @@ class ChatDeepInfra(BaseChatModel):
def _body(self, kwargs: Any) -> Dict: def _body(self, kwargs: Any) -> Dict:
return kwargs return kwargs
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]: def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
for line in rbody: for line in rbody:

View File

@ -1,11 +1,23 @@
"""Test ChatDeepInfra wrapper.""" """Test ChatDeepInfra wrapper."""
from typing import List
from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import RunnableBinding
from langchain_community.chat_models.deepinfra import ChatDeepInfra from langchain_community.chat_models.deepinfra import ChatDeepInfra
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class GenerateMovieName(BaseModel):
"Get a movie name from a description"
description: str
def test_chat_deepinfra() -> None: def test_chat_deepinfra() -> None:
"""Test valid call to DeepInfra.""" """Test valid call to DeepInfra."""
chat = ChatDeepInfra( chat = ChatDeepInfra(
@ -63,3 +75,51 @@ async def test_async_chat_deepinfra_streaming() -> None:
assert isinstance(generation, ChatGeneration) assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str) assert isinstance(generation.text, str)
assert generation.text == generation.message.content assert generation.text == generation.message.content
def test_chat_deepinfra_bind_tools() -> None:
class Foo(BaseModel):
pass
chat = ChatDeepInfra(
max_tokens=10,
)
tools = [Foo]
chat_with_tools = chat.bind_tools(tools)
assert isinstance(chat_with_tools, RunnableBinding)
chat_tools = chat_with_tools.tools
assert chat_tools
assert chat_tools == {
"tools": [
{
"function": {
"description": "",
"name": "Foo",
"parameters": {"properties": {}, "type": "object"},
},
"type": "function",
}
]
}
def test_tool_use() -> None:
llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateMovieName], tool_choice=True)
msgs: List = [
HumanMessage(content="It should be a movie explaining humanity in 2133.")
]
ai_msg = llm_with_tool.invoke(msgs)
assert isinstance(ai_msg, AIMessage)
assert isinstance(ai_msg.tool_calls, list)
assert len(ai_msg.tool_calls) == 1
tool_call = ai_msg.tool_calls[0]
assert "args" in tool_call
tool_msg = ToolMessage(
content="Year 2133",
tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"],
)
msgs.extend([ai_msg, tool_msg])
llm_with_tool.invoke(msgs)