diff --git a/docs/docs/integrations/chat/deepinfra.ipynb b/docs/docs/integrations/chat/deepinfra.ipynb
index f3f3704ccf..e8d6d3465a 100644
--- a/docs/docs/integrations/chat/deepinfra.ipynb
+++ b/docs/docs/integrations/chat/deepinfra.ipynb
@@ -98,6 +98,78 @@
     ")\n",
     "chat.invoke(messages)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "466c3cb41ace1410",
+   "metadata": {},
+   "source": [
+    "# Tool Calling\n",
+    "\n",
+    "DeepInfra currently supports tool calling only through `invoke` and `ainvoke`; streaming tool calls are not supported.\n",
+    "\n",
+    "For a complete list of models that support tool calling, please refer to our [tool calling documentation](https://deepinfra.com/docs/advanced/function_calling)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ddc4f4299763651c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import asyncio\n",
+    "\n",
+    "from dotenv import find_dotenv, load_dotenv\n",
+    "from langchain_community.chat_models import ChatDeepInfra\n",
+    "from langchain_core.messages import HumanMessage\n",
+    "from langchain_core.pydantic_v1 import BaseModel\n",
+    "from langchain_core.tools import tool\n",
+    "\n",
+    "model_name = \"meta-llama/Meta-Llama-3-70B-Instruct\"\n",
+    "\n",
+    "_ = load_dotenv(find_dotenv())\n",
+    "\n",
+    "\n",
+    "# LangChain tool\n",
+    "@tool\n",
+    "def foo(something):\n",
+    "    \"\"\"\n",
+    "    Called when foo\n",
+    "    \"\"\"\n",
+    "    pass\n",
+    "\n",
+    "\n",
+    "# Pydantic class\n",
+    "class Bar(BaseModel):\n",
+    "    \"\"\"\n",
+    "    Called when Bar\n",
+    "    \"\"\"\n",
+    "\n",
+    "    pass\n",
+    "\n",
+    "\n",
+    "llm = ChatDeepInfra(model=model_name)\n",
+    "tools = [foo, Bar]\n",
+    "llm_with_tools = llm.bind_tools(tools)\n",
+    "messages = [\n",
+    "    HumanMessage(\"Foo and bar, please.\"),\n",
+    "]\n",
+    "\n",
+    "response = llm_with_tools.invoke(messages)\n",
+    "print(response.tool_calls)\n",
+    "# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_Mi4N4wAtW89OlbizFE1aDxDj'}, {'name': 'Bar', 'args': {}, 'id': 'call_daiE0mW454j2O1KVbmET4s2r'}]\n",
+    "\n",
+    "\n",
+    "async def call_ainvoke():\n",
+    "    result = await llm_with_tools.ainvoke(messages)\n",
+    "    print(result.tool_calls)\n",
+    "\n",
+    "\n",
+    "# Async call\n",
+    "asyncio.run(call_ainvoke())\n",
+    "# [{'name': 'foo', 'args': {'something': None}, 'id': 'call_ZH7FetmgSot4LHcMU6CEb8tI'}, {'name': 'Bar', 'args': {}, 'id': 'call_2MQhDifAJVoijZEvH8PeFSVB'}]"
+   ]
   }
  ],
 "metadata": {
diff --git a/libs/community/langchain_community/chat_models/deepinfra.py b/libs/community/langchain_community/chat_models/deepinfra.py
index 51df3b634b..32be0867a0 100644
--- a/libs/community/langchain_community/chat_models/deepinfra.py
+++ b/libs/community/langchain_community/chat_models/deepinfra.py
@@ -13,6 +13,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -24,6 +25,7 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -44,15 +46,18 @@ from langchain_core.messages import (
     SystemMessage,
     SystemMessageChunk,
 )
+from langchain_core.messages.tool import ToolCall
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
-# from langchain.llms.base import create_base_retry_decorator
 from langchain_community.utilities.requests import Requests
 
 logger = logging.getLogger(__name__)
@@ -78,19 +83,51 @@ def _create_retry_decorator(
     )
 
 
+def _parse_tool_calling(tool_call: dict) -> ToolCall:
+    """Convert a tool call returned by the server into a ``ToolCall`` object.
+
+    Args:
+        tool_call: An OpenAI-style tool call dict from the API response.
+
+    Returns:
+        The equivalent ``ToolCall``.
+    """
+    # Accept the tool name either at the top level or nested under "function"
+    # (the OpenAI-compatible layout), whichever the server returns.
+    name = tool_call.get("name") or tool_call.get("function", {}).get("name", "")
+    args = json.loads(tool_call["function"]["arguments"])
+    return ToolCall(name=name, args=args, id=tool_call.get("id"))
+
+
+def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
+    """Convert a ``ToolCall`` object into the request format the server expects.
+
+    Args:
+        tool_call: The ``ToolCall`` to serialize.
+
+    Returns:
+        An OpenAI-style tool call dict.
+    """
+    return {
+        "type": "function",
+        "function": {
+            "arguments": json.dumps(tool_call["args"]),
+            "name": tool_call["name"],
+        },
+        "id": tool_call.get("id"),
+    }
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
         return HumanMessage(content=_dict["content"])
     elif role == "assistant":
-        # Fix for azure
-        # Also OpenAI returns None for tool invocations
         content = _dict.get("content", "") or ""
-        if _dict.get("function_call"):
-            additional_kwargs = {"function_call": dict(_dict["function_call"])}
-        else:
-            additional_kwargs = {}
-        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+        tool_calls_content = _dict.get("tool_calls", []) or []
+        tool_calls = [
+            _parse_tool_calling(tool_call) for tool_call in tool_calls_content
+        ]
+        return AIMessage(content=content, tool_calls=tool_calls)
     elif role == "system":
         return SystemMessage(content=_dict["content"])
     elif role == "function":
@@ -104,15 +141,14 @@ def _convert_delta_to_message_chunk(
 ) -> BaseMessageChunk:
     role = _dict.get("role")
     content = _dict.get("content") or ""
-    if _dict.get("function_call"):
-        additional_kwargs = {"function_call": dict(_dict["function_call"])}
-    else:
-        additional_kwargs = {}
 
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        tool_calls = [
+            _parse_tool_calling(tool_call)
+            for tool_call in _dict.get("tool_calls") or []
+        ]
+        return AIMessageChunk(content=content, tool_calls=tool_calls)
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role == "function" or default_class == FunctionMessageChunk:
@@ -129,9 +165,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     elif isinstance(message, HumanMessage):
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
-        if "function_call" in message.additional_kwargs:
-            message_dict["function_call"] = message.additional_kwargs["function_call"]
+        tool_calls = [
+            _convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
+        ]
+        message_dict = {
+            "role": "assistant",
+            "content": message.content,
+            "tool_calls": tool_calls,  # type: ignore[dict-item]
+        }
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
     elif isinstance(message, FunctionMessage):
@@ -417,6 +458,27 @@ class ChatDeepInfra(BaseChatModel):
     def _body(self, kwargs: Any) -> Dict:
         return kwargs
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes the model is compatible with the OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain_core.runnables.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
+
 
 def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
diff --git a/libs/community/tests/integration_tests/chat_models/test_deepinfra.py b/libs/community/tests/integration_tests/chat_models/test_deepinfra.py
index 0fa4593ace..572cec0522 100644
--- a/libs/community/tests/integration_tests/chat_models/test_deepinfra.py
+++ b/libs/community/tests/integration_tests/chat_models/test_deepinfra.py
@@ -1,11 +1,23 @@
 """Test ChatDeepInfra wrapper."""
+from typing import List
+
 from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.messages.ai import AIMessage
+from langchain_core.messages.tool import ToolMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.runnables.base import RunnableBinding
 
 from langchain_community.chat_models.deepinfra import ChatDeepInfra
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
+class GenerateMovieName(BaseModel):
+    """Get a movie name from a description."""
+
+    description: str
+
+
 def test_chat_deepinfra() -> None:
     """Test valid call to DeepInfra."""
     chat = ChatDeepInfra(
@@ -63,3 +75,51 @@ async def test_async_chat_deepinfra_streaming() -> None:
     assert isinstance(generation, ChatGeneration)
     assert isinstance(generation.text, str)
     assert generation.text == generation.message.content
+
+
+def test_chat_deepinfra_bind_tools() -> None:
+    class Foo(BaseModel):
+        pass
+
+    chat = ChatDeepInfra(
+        max_tokens=10,
+    )
+    tools = [Foo]
+    chat_with_tools = chat.bind_tools(tools)
+    assert isinstance(chat_with_tools, RunnableBinding)
+    # The bound tool schemas live in the binding's kwargs.
+    chat_tools = chat_with_tools.kwargs
+    assert chat_tools
+    assert chat_tools == {
+        "tools": [
+            {
+                "function": {
+                    "description": "",
+                    "name": "Foo",
+                    "parameters": {"properties": {}, "type": "object"},
+                },
+                "type": "function",
+            }
+        ]
+    }
+
+
+def test_tool_use() -> None:
+    llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateMovieName], tool_choice=True)
+    msgs: List = [
+        HumanMessage(content="It should be a movie explaining humanity in 2133.")
+    ]
+    ai_msg = llm_with_tool.invoke(msgs)
+
+    assert isinstance(ai_msg, AIMessage)
+    assert isinstance(ai_msg.tool_calls, list)
+    assert len(ai_msg.tool_calls) == 1
+    tool_call = ai_msg.tool_calls[0]
+    assert "args" in tool_call
+
+    tool_msg = ToolMessage(
+        content="Year 2133",
+        # Messages parsed by the new code expose tool calls on `.tool_calls`,
+        # not in `additional_kwargs`.
+        tool_call_id=ai_msg.tool_calls[0]["id"],
+    )
+    msgs.extend([ai_msg, tool_msg])
+    llm_with_tool.invoke(msgs)
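
A note for reviewers (not part of the patch): the notebook demonstrates requesting tool calls but stops before feeding results back to the model. Below is a minimal end-to-end sketch of the loop this PR enables, mirroring what `test_tool_use` does: bind a tool, let the model request it, execute it locally, and reply with a `ToolMessage` keyed by the call's `id`. It assumes `DEEPINFRA_API_TOKEN` is set in the environment and that the `meta-llama/Meta-Llama-3-70B-Instruct` endpoint supports tool calling; `get_weather` is a made-up placeholder tool, not something shipped by this diff.

```python
from langchain_community.chat_models import ChatDeepInfra
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    # Placeholder body; a real tool would call a weather API here.
    return f"It is sunny in {city}."


llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct", temperature=0)
llm_with_tools = llm.bind_tools([get_weather])

messages = [HumanMessage("What's the weather in Paris?")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)

# Run each requested tool and answer with a ToolMessage that echoes the
# call id, so the model can match results to its requests.
for call in ai_msg.tool_calls:
    result = get_weather.invoke(call["args"])
    messages.append(ToolMessage(content=result, tool_call_id=call["id"]))

# Second round trip: the model now answers in plain text.
print(llm_with_tools.invoke(messages).content)
```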