mirror of https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00

community[minor]: add tool calling for DeepInfraChat (#22745)

DeepInfra now supports tool calling for supported models.

Co-authored-by: Bagatur <baskaryan@gmail.com>

This commit is contained in:
parent 158701ab3c
commit dd25d08c06
@@ -98,6 +98,78 @@
     ")\n",
     "chat.invoke(messages)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "466c3cb41ace1410",
+   "metadata": {},
+   "source": [
+    "# Tool Calling\n",
+    "\n",
+    "DeepInfra currently supports only invoke and async invoke tool calling.\n",
+    "\n",
+    "For a complete list of models that support tool calling, please refer to our [tool calling documentation](https://deepinfra.com/docs/advanced/function_calling)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ddc4f4299763651c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import asyncio\n",
+    "\n",
+    "from dotenv import find_dotenv, load_dotenv\n",
+    "from langchain_community.chat_models import ChatDeepInfra\n",
+    "from langchain_core.messages import HumanMessage\n",
+    "from langchain_core.pydantic_v1 import BaseModel\n",
+    "from langchain_core.tools import tool\n",
+    "\n",
+    "model_name = \"meta-llama/Meta-Llama-3-70B-Instruct\"\n",
+    "\n",
+    "_ = load_dotenv(find_dotenv())\n",
+    "\n",
+    "\n",
+    "# Langchain tool\n",
+    "@tool\n",
+    "def foo(something):\n",
+    "    \"\"\"\n",
+    "    Called when foo\n",
+    "    \"\"\"\n",
+    "    pass\n",
+    "\n",
+    "\n",
+    "# Pydantic class\n",
+    "class Bar(BaseModel):\n",
+    "    \"\"\"\n",
+    "    Called when Bar\n",
+    "    \"\"\"\n",
+    "\n",
+    "    pass\n",
+    "\n",
+    "\n",
+    "llm = ChatDeepInfra(model=model_name)\n",
+    "tools = [foo, Bar]\n",
+    "llm_with_tools = llm.bind_tools(tools)\n",
+    "messages = [\n",
+    "    HumanMessage(\"Foo and bar, please.\"),\n",
+    "]\n",
+    "\n",
+    "response = llm_with_tools.invoke(messages)\n",
+    "print(response.tool_calls)\n",
+    "# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_Mi4N4wAtW89OlbizFE1aDxDj'}, {'name': 'Bar', 'args': {}, 'id': 'call_daiE0mW454j2O1KVbmET4s2r'}]\n",
+    "\n",
+    "\n",
+    "async def call_ainvoke():\n",
+    "    result = await llm_with_tools.ainvoke(messages)\n",
+    "    print(result.tool_calls)\n",
+    "\n",
+    "\n",
+    "# Async call\n",
+    "asyncio.run(call_ainvoke())\n",
+    "# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_ZH7FetmgSot4LHcMU6CEb8tI'}, {'name': 'Bar', 'args': {}, 'id': 'call_2MQhDifAJVoijZEvH8PeFSVB'}]"
+   ]
   }
  ],
  "metadata": {
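The notebook stops after printing the parsed tool calls. To complete the picture, here is a minimal follow-up sketch (not part of the commit) that executes the returned calls and feeds the results back, reusing the `llm_with_tools`, `foo`, and `Bar` objects defined above; the dispatch logic and the placeholder result for `Bar` are illustrative assumptions:

    from langchain_core.messages import HumanMessage, ToolMessage

    messages = [HumanMessage("Foo and bar, please.")]
    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)

    for tool_call in ai_msg.tool_calls:
        # `foo` is an executable @tool; `Bar` is only a schema, so fake a result.
        # Assumes the server returned an id for every call.
        if tool_call["name"] == "foo":
            result = foo.invoke(tool_call["args"])
        else:
            result = "Bar acknowledged"
        messages.append(ToolMessage(content=str(result), tool_call_id=tool_call["id"]))

    # A second round trip lets the model incorporate the tool results.
    print(llm_with_tools.invoke(messages).content)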
@@ -13,6 +13,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -24,6 +25,7 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -44,15 +46,18 @@ from langchain_core.messages import (
     SystemMessage,
     SystemMessageChunk,
 )
+from langchain_core.messages.tool import ToolCall
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
-# from langchain.llms.base import create_base_retry_decorator
 from langchain_community.utilities.requests import Requests
 
 logger = logging.getLogger(__name__)
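The functional core of these import changes is `convert_to_openai_tool`, which normalizes dicts, pydantic models, callables, and `BaseTool`s into the OpenAI-style schema that the DeepInfra API expects. A small illustration, reusing the `GenerateMovieName` model from the test file below; the printed shape matches what `test_chat_deepinfra_bind_tools` asserts, up to key order and minor schema details:

    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.utils.function_calling import convert_to_openai_tool


    class GenerateMovieName(BaseModel):
        "Get a movie name from a description"

        description: str


    print(convert_to_openai_tool(GenerateMovieName))
    # Roughly:
    # {'type': 'function',
    #  'function': {'name': 'GenerateMovieName',
    #               'description': 'Get a movie name from a description',
    #               'parameters': {'type': 'object',
    #                              'properties': {'description': {'type': 'string'}},
    #                              'required': ['description']}}}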
@@ -78,19 +83,51 @@ def _create_retry_decorator(
     )
 
 
+def _parse_tool_calling(tool_call: dict) -> ToolCall:
+    """
+    Convert a tool calling response from server to a ToolCall object.
+    Args:
+        tool_call:
+
+    Returns:
+
+    """
+    name = tool_call.get("name", "")
+    args = json.loads(tool_call["function"]["arguments"])
+    id = tool_call.get("id")
+    return ToolCall(name=name, args=args, id=id)
+
+
+def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
+    """
+    Convert a ToolCall object to a tool calling request for server.
+    Args:
+        tool_call:
+
+    Returns:
+
+    """
+    return {
+        "type": "function",
+        "function": {
+            "arguments": json.dumps(tool_call["args"]),
+            "name": tool_call["name"],
+        },
+        "id": tool_call.get("id"),
+    }
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
         return HumanMessage(content=_dict["content"])
     elif role == "assistant":
-        # Fix for azure
-        # Also OpenAI returns None for tool invocations
         content = _dict.get("content", "") or ""
-        if _dict.get("function_call"):
-            additional_kwargs = {"function_call": dict(_dict["function_call"])}
-        else:
-            additional_kwargs = {}
-        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+        tool_calls_content = _dict.get("tool_calls", []) or []
+        tool_calls = [
+            _parse_tool_calling(tool_call) for tool_call in tool_calls_content
+        ]
+        return AIMessage(content=content, tool_calls=tool_calls)
     elif role == "system":
         return SystemMessage(content=_dict["content"])
     elif role == "function":
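A quick round trip through the two new helpers, with hypothetical values. Note the hybrid shape `_parse_tool_calling` reads: `name` and `id` at the top level of the dict, while the JSON-encoded `arguments` sit nested under `function`:

    import json

    raw = {
        "name": "foo",
        "id": "call_abc123",  # hypothetical id
        "function": {"arguments": json.dumps({"something": None})},
    }

    _parse_tool_calling(raw)
    # -> {'name': 'foo', 'args': {'something': None}, 'id': 'call_abc123'}

    _convert_to_tool_calling(_parse_tool_calling(raw))
    # -> {'type': 'function',
    #     'function': {'arguments': '{"something": null}', 'name': 'foo'},
    #     'id': 'call_abc123'}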
@@ -104,15 +141,14 @@ def _convert_delta_to_message_chunk(
 ) -> BaseMessageChunk:
     role = _dict.get("role")
     content = _dict.get("content") or ""
-    if _dict.get("function_call"):
-        additional_kwargs = {"function_call": dict(_dict["function_call"])}
-    else:
-        additional_kwargs = {}
 
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        tool_calls = [
+            _parse_tool_calling(tool_call) for tool_call in _dict.get("tool_calls", [])
+        ]
+        return AIMessageChunk(content=content, tool_calls=tool_calls)
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role == "function" or default_class == FunctionMessageChunk:
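The streaming path mirrors the non-streaming one: assistant deltas are parsed with the same helper, and a delta without a `tool_calls` key simply yields an empty list. A minimal sketch with a hypothetical delta payload:

    delta = {"role": "assistant", "content": "Hel"}  # no "tool_calls" key
    _convert_delta_to_message_chunk(delta, AIMessageChunk)
    # -> an AIMessageChunk with content='Hel' and no tool calls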
@@ -129,9 +165,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     elif isinstance(message, HumanMessage):
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
-        if "function_call" in message.additional_kwargs:
-            message_dict["function_call"] = message.additional_kwargs["function_call"]
+        tool_calls = [
+            _convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
+        ]
+        message_dict = {
+            "role": "assistant",
+            "content": message.content,
+            "tool_calls": tool_calls,  # type: ignore[dict-item]
+        }
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
     elif isinstance(message, FunctionMessage):
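Going the other direction, an `AIMessage` carrying tool calls (for example one appended back into the history, as `test_tool_use` does below) is now serialized with a `tool_calls` list in the request body. A sketch with hypothetical values:

    from langchain_core.messages import AIMessage

    msg = AIMessage(
        content="",
        tool_calls=[{"name": "foo", "args": {"something": None}, "id": "call_abc123"}],
    )
    _convert_message_to_dict(msg)
    # -> {'role': 'assistant',
    #     'content': '',
    #     'tool_calls': [{'type': 'function',
    #                     'function': {'arguments': '{"something": null}', 'name': 'foo'},
    #                     'id': 'call_abc123'}]}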
@@ -417,6 +458,27 @@ class ChatDeepInfra(BaseChatModel):
     def _body(self, kwargs: Any) -> Dict:
         return kwargs
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
+
 
 def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
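Because `bind_tools` only normalizes the tools with `convert_to_openai_tool` and defers to `Runnable.bind`, any extra keyword argument becomes a request parameter; that is what lets the integration test below pass `tool_choice=True` straight through. A usage sketch reusing the `GenerateMovieName` schema from the tests:

    llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct")
    # Dicts, pydantic models, callables, and BaseTools are all accepted;
    # extra kwargs (here tool_choice) are forwarded with the request.
    llm_with_tools = llm.bind_tools([GenerateMovieName], tool_choice=True)
    print(llm_with_tools.invoke("A movie about humanity in 2133.").tool_calls)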
@@ -1,11 +1,23 @@
 """Test ChatDeepInfra wrapper."""
 
+from typing import List
+
 from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.messages.ai import AIMessage
+from langchain_core.messages.tool import ToolMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.runnables.base import RunnableBinding
 
 from langchain_community.chat_models.deepinfra import ChatDeepInfra
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
+class GenerateMovieName(BaseModel):
+    "Get a movie name from a description"
+
+    description: str
+
+
 def test_chat_deepinfra() -> None:
     """Test valid call to DeepInfra."""
     chat = ChatDeepInfra(
@@ -63,3 +75,51 @@ async def test_async_chat_deepinfra_streaming() -> None:
     assert isinstance(generation, ChatGeneration)
     assert isinstance(generation.text, str)
     assert generation.text == generation.message.content
+
+
+def test_chat_deepinfra_bind_tools() -> None:
+    class Foo(BaseModel):
+        pass
+
+    chat = ChatDeepInfra(
+        max_tokens=10,
+    )
+    tools = [Foo]
+    chat_with_tools = chat.bind_tools(tools)
+    assert isinstance(chat_with_tools, RunnableBinding)
+    chat_tools = chat_with_tools.tools
+    assert chat_tools
+    assert chat_tools == {
+        "tools": [
+            {
+                "function": {
+                    "description": "",
+                    "name": "Foo",
+                    "parameters": {"properties": {}, "type": "object"},
+                },
+                "type": "function",
+            }
+        ]
+    }
+
+
+def test_tool_use() -> None:
+    llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateMovieName], tool_choice=True)
+    msgs: List = [
+        HumanMessage(content="It should be a movie explaining humanity in 2133.")
+    ]
+    ai_msg = llm_with_tool.invoke(msgs)
+
+    assert isinstance(ai_msg, AIMessage)
+    assert isinstance(ai_msg.tool_calls, list)
+    assert len(ai_msg.tool_calls) == 1
+    tool_call = ai_msg.tool_calls[0]
+    assert "args" in tool_call
+
+    tool_msg = ToolMessage(
+        content="Year 2133",
+        tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"],
+    )
+    msgs.extend([ai_msg, tool_msg])
+    llm_with_tool.invoke(msgs)