diff --git a/cookbook/tool_call_messages.ipynb b/cookbook/tool_call_messages.ipynb index c7253e44bd..3d533c89b7 100644 --- a/cookbook/tool_call_messages.ipynb +++ b/cookbook/tool_call_messages.ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/chestercurme/repos/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `bind_tools` is in beta. It is actively being worked on, so the API may change.\n", - " warn_beta(\n" - ] - } - ], + "outputs": [], "source": [ "from langchain_anthropic import ChatAnthropic\n", "from langchain_core.runnables import ConfigurableField\n", @@ -49,221 +40,6 @@ ")" ] }, - { - "cell_type": "markdown", - "id": "4719ebdb-ad50-468e-9b30-fb5fb086e140", - "metadata": {}, - "source": [ - "# AgentExecutor" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "b98feaa5-8c2d-4125-9519-67114a1fef31", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List, Tuple, Union\n", - "\n", - "from langchain.agents import AgentExecutor\n", - "from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction\n", - "from langchain_core.agents import AgentFinish, _convert_agent_action_to_messages\n", - "from langchain_core.messages import (\n", - " AIMessage,\n", - " BaseMessage,\n", - " ToolMessage,\n", - ")\n", - "from langchain_core.prompts import ChatPromptTemplate\n", - "from langchain_core.runnables import RunnablePassthrough\n", - "\n", - "\n", - "def actions_observations_to_messages(\n", - " steps: List[Tuple[OpenAIToolAgentAction, str]],\n", - ") -> List[BaseMessage]:\n", - " messages = []\n", - " for action, observation in steps:\n", - " messages.extend([m for m in action.message_log if m not in messages])\n", - " messages.append(ToolMessage(observation, tool_call_id=action.tool_call_id))\n", - " return messages\n", - "\n", - "\n", - "def messages_to_action(\n", - " msg: AIMessage,\n", - ") -> Union[List[OpenAIToolAgentAction], AgentFinish]:\n", - " if isinstance(msg, AIMessage) and msg.tool_calls is not None:\n", - " actions = []\n", - " for tool_call in msg.tool_calls:\n", - " actions.append(\n", - " OpenAIToolAgentAction(\n", - " tool=tool_call.name,\n", - " tool_input=tool_call.args,\n", - " tool_call_id=tool_call.id,\n", - " message_log=[msg],\n", - " log=\"\",\n", - " )\n", - " )\n", - " return actions\n", - " else:\n", - " return AgentFinish(return_values={\"output\": msg.content}, log=\"\")\n", - "\n", - "\n", - "prompt = ChatPromptTemplate.from_messages(\n", - " [\n", - " (\"system\", \"You're a helpful assistant with access to tools\"),\n", - " (\"human\", \"{input}\"),\n", - " (\"placeholder\", \"{agent_scratchpad}\"),\n", - " ]\n", - ")\n", - "\n", - "agent = (\n", - " RunnablePassthrough.assign(\n", - " agent_scratchpad=lambda x: actions_observations_to_messages(\n", - " x[\"intermediate_steps\"]\n", - " ),\n", - " )\n", - " | prompt\n", - " | llm_with_tools\n", - " | messages_to_action\n", - ")\n", - "\n", - "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "b4c0fc7a-80bb-4bb8-a87b-7388291ae8b6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m300.03770462067547\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n", - " 'output': 'The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.'}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor.invoke(\n", - " {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"}\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "41a3a3c8-185d-4861-b6f0-7592668feb62", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", - " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m82.65606421491815\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", - " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m85.65606421491815\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", - " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", - " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n", - " 'output': 'Therefore, 17.24 - 918.1241 = -900.8841'}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor = AgentExecutor(\n", - " agent=agent.with_config(configurable={\"llm\": \"claude3\"}), tools=tools, verbose=True\n", - ")\n", - "agent_executor.invoke(\n", - " {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"},\n", - ")" - ] - }, { "cell_type": "markdown", "id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811", @@ -274,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2", "metadata": {}, "outputs": [], @@ -282,7 +58,7 @@ "import operator\n", "from typing import Annotated, Sequence, TypedDict\n", "\n", - "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n", + "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n", "from langchain_core.runnables import RunnableLambda\n", "from langgraph.graph import END, StateGraph\n", "\n", @@ -292,7 +68,7 @@ "\n", "\n", "def should_continue(state):\n", - " return \"continue\" if state[\"messages\"][-1].tool_calls is not None else \"end\"\n", + " return \"continue\" if state[\"messages\"][-1].tool_calls else \"end\"\n", "\n", "\n", "def call_model(state, config):\n", @@ -300,8 +76,8 @@ "\n", "\n", "def _invoke_tool(tool_call):\n", - " tool = {tool.name: tool for tool in tools}[tool_call.name]\n", - " return ToolMessage(tool.invoke(tool_call.args), tool_call_id=tool_call.id)\n", + " tool = {tool.name: tool for tool in tools}[tool_call[\"name\"]]\n", + " return ToolMessage(tool.invoke(tool_call[\"args\"]), tool_call_id=tool_call[\"id\"])\n", "\n", "\n", "tool_executor = RunnableLambda(_invoke_tool)\n", @@ -330,21 +106,21 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "24463798-74e6-4c39-8092-7a1524d83225", + "execution_count": 4, + "id": "3710e724-2595-4625-ba3a-effb81e66e4a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n", - " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kbBUUeqK75fZZqDTvu8aif7Z', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_pBD8daSyXidXnrIyG4vG5C9O', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e1d9687-611c-4c8e-9fcd-ef6e48bd22a6-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 8, 'y': 2.743}, id='call_kbBUUeqK75fZZqDTvu8aif7Z'), ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='call_pBD8daSyXidXnrIyG4vG5C9O')]),\n", - " ToolMessage(content='300.03770462067547', tool_call_id='call_kbBUUeqK75fZZqDTvu8aif7Z'),\n", - " ToolMessage(content='-900.8841', tool_call_id='call_pBD8daSyXidXnrIyG4vG5C9O'),\n", - " AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-47fe5cbc-3f25-44c3-85b2-6540c3054a77-0')]}" + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-528302fc-7acf-4c11-82c4-119ccf40c573-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8, 'y': 2.743}, 'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc'}, {'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp'}]),\n", + " ToolMessage(content='300.03770462067547', tool_call_id='call_6yMU2WsS4Bqgi1WxFHxtfJRc'),\n", + " ToolMessage(content='-900.8841', tool_call_id='call_GAL3dQiKFF9XEV0RrRLPTvVp'),\n", + " AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-d1161669-ed09-4b18-94bd-6d8530df5aa8-0')]}" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -363,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "073c074e-d722-42e0-85ec-c62c079207e4", "metadata": {}, "outputs": [ @@ -371,16 +147,16 @@ "data": { "text/plain": [ "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n", - " AIMessage(content=[{'text': \"Okay, let's break this down into steps:\", 'type': 'text'}, {'id': 'toolu_01DJkSDpB8ztmJx2DLNbc3eW', 'input': {'x': 5, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_01KuVNohyJr24cPhJkY3XVtt', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 84}}, id='run-336cdfb6-0fe4-4d7a-9946-9f01c2eb41ae-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 5, 'y': 2.743}, id='toolu_01DJkSDpB8ztmJx2DLNbc3eW', index=1)]),\n", - " ToolMessage(content='82.65606421491815', tool_call_id='toolu_01DJkSDpB8ztmJx2DLNbc3eW'),\n", - " AIMessage(content=[{'text': 'To get 5 raised to the 2.743 power.', 'type': 'text'}, {'id': 'toolu_01MKQqnDw5CtyuKjQP8YG1FX', 'input': {'x': 3, 'y': 82.65606421491815}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01UBsKkvA4StUR4NEvoFFFep', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 552, 'output_tokens': 91}}, id='run-9d25b7bd-58aa-47dd-933f-15459b24b2c2-0', tool_calls=[ToolCall(name='add', args={'x': 3, 'y': 82.65606421491815}, id='toolu_01MKQqnDw5CtyuKjQP8YG1FX', index=1)]),\n", - " ToolMessage(content='85.65606421491815', tool_call_id='toolu_01MKQqnDw5CtyuKjQP8YG1FX'),\n", - " AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 85.656.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_019Wb2zPouCR3dw2bSKvCRUL', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01Y2H2L8FWcDtVkCtuosie2P', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 661, 'output_tokens': 105}}, id='run-e553c1e3-24ba-4e1b-93ba-6f1985932db4-0', tool_calls=[ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='toolu_019Wb2zPouCR3dw2bSKvCRUL', index=1)]),\n", - " ToolMessage(content='-900.8841', tool_call_id='toolu_019Wb2zPouCR3dw2bSKvCRUL'),\n", - " AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01Q14dqvaCD2eA4zwrUvxTcF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 782, 'output_tokens': 24}}, id='run-f6b6e525-2df6-4617-9bb3-b39d5cc963a9-0')]}" + " AIMessage(content=[{'text': \"Okay, let's break this down into two parts:\", 'type': 'text'}, {'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC', 'input': {'x': 3, 'y': 5}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01AkLGH8sxMHaH15yewmjwkF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 81}}, id='run-f35bfae8-8ded-4f8a-831b-0940d6ad16b6-0', tool_calls=[{'name': 'add', 'args': {'x': 3, 'y': 5}, 'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC'}]),\n", + " ToolMessage(content='8.0', tool_call_id='toolu_01DEhqcXkXTtzJAiZ7uMBeDC'),\n", + " AIMessage(content=[{'id': 'toolu_013DyMLrvnrto33peAKMGMr1', 'input': {'x': 8.0, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_015Fmp8aztwYcce2JDAFfce3', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 545, 'output_tokens': 75}}, id='run-48aaeeeb-a1e5-48fd-a57a-6c3da2907b47-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8.0, 'y': 2.743}, 'id': 'toolu_013DyMLrvnrto33peAKMGMr1'}]),\n", + " ToolMessage(content='300.03770462067547', tool_call_id='toolu_013DyMLrvnrto33peAKMGMr1'),\n", + " AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 300.04.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_015TkhfRBENPib2RWAxkieH6', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 638, 'output_tokens': 105}}, id='run-45fb62e3-d102-4159-881d-241c5dbadeed-0', tool_calls=[{'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46'}]),\n", + " ToolMessage(content='-900.8841', tool_call_id='toolu_01UTmMrGTmLpPrPCF1rShN46'),\n", + " AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01LgKnRuUcSyADCpxv9tPoYD', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 759, 'output_tokens': 24}}, id='run-1008254e-ccd1-497c-8312-9550dd77bd08-0')]}" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" }