diff --git a/libs/langchain/langchain/agents/tool_calling_agent/base.py b/libs/langchain/langchain/agents/tool_calling_agent/base.py index 917abab0ed..6266cecf1c 100644 --- a/libs/langchain/langchain/agents/tool_calling_agent/base.py +++ b/libs/langchain/langchain/agents/tool_calling_agent/base.py @@ -1,6 +1,8 @@ -from typing import Sequence +from typing import Callable, List, Sequence, Tuple +from langchain_core.agents import AgentAction from langchain_core.language_models import BaseLanguageModel +from langchain_core.messages import BaseMessage from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool @@ -10,9 +12,15 @@ from langchain.agents.format_scratchpad.tools import ( ) from langchain.agents.output_parsers.tools import ToolsAgentOutputParser +MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]] + def create_tool_calling_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + prompt: ChatPromptTemplate, + *, + message_formatter: MessageFormatter = format_to_tool_messages, ) -> Runnable: """Create an agent that uses tools. @@ -21,6 +29,8 @@ def create_tool_calling_agent( tools: Tools this agent has access to. prompt: The prompt to use. See Prompt section below for more on the expected input variables. + message_formatter: Formatter function to convert (AgentAction, tool output) + tuples into FunctionMessages. Returns: A Runnable sequence representing an agent. It takes as input all the same input @@ -89,7 +99,7 @@ def create_tool_calling_agent( agent = ( RunnablePassthrough.assign( - agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"]) + agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]) ) | prompt | llm_with_tools