From 39f6c4169d4ff47542aa4ef15caa980b06d51e8b Mon Sep 17 00:00:00 2001 From: mackong Date: Tue, 18 Jun 2024 11:29:00 +0800 Subject: [PATCH] langchain[patch]: add tool messages formatter for tool calling agent (#22849) - **Description:** add tool_messages_formatter for tool calling agent, make tool messages can be formatted in different ways for your LLM. - **Issue:** N/A - **Dependencies:** N/A --- .../langchain/agents/tool_calling_agent/base.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/agents/tool_calling_agent/base.py b/libs/langchain/langchain/agents/tool_calling_agent/base.py index 917abab0ed..6266cecf1c 100644 --- a/libs/langchain/langchain/agents/tool_calling_agent/base.py +++ b/libs/langchain/langchain/agents/tool_calling_agent/base.py @@ -1,6 +1,8 @@ -from typing import Sequence +from typing import Callable, List, Sequence, Tuple +from langchain_core.agents import AgentAction from langchain_core.language_models import BaseLanguageModel +from langchain_core.messages import BaseMessage from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool @@ -10,9 +12,15 @@ from langchain.agents.format_scratchpad.tools import ( ) from langchain.agents.output_parsers.tools import ToolsAgentOutputParser +MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]] + def create_tool_calling_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + prompt: ChatPromptTemplate, + *, + message_formatter: MessageFormatter = format_to_tool_messages, ) -> Runnable: """Create an agent that uses tools. @@ -21,6 +29,8 @@ def create_tool_calling_agent( tools: Tools this agent has access to. prompt: The prompt to use. See Prompt section below for more on the expected input variables. + message_formatter: Formatter function to convert (AgentAction, tool output) + tuples into FunctionMessages. Returns: A Runnable sequence representing an agent. It takes as input all the same input @@ -89,7 +99,7 @@ def create_tool_calling_agent( agent = ( RunnablePassthrough.assign( - agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"]) + agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]) ) | prompt | llm_with_tools