mirror of https://github.com/hwchase17/langchain
update agents to use tool call messages (#20074)
```python from langchain.agents import AgentExecutor, create_tool_calling_agent, tool from langchain_anthropic import ChatAnthropic from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder prompt = ChatPromptTemplate.from_messages( [ ("system", "You are a helpful assistant"), MessagesPlaceholder("chat_history", optional=True), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad"), ] ) model = ChatAnthropic(model="claude-3-opus-20240229") @tool def magic_function(input: int) -> int: """Applies a magic function to an input.""" return input + 2 tools = [magic_function] agent = create_tool_calling_agent(model, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor.invoke({"input": "what is the value of magic_function(3)?"}) ``` ``` > Entering new AgentExecutor chain... Invoking: `magic_function` with `{'input': 3}` responded: [{'text': '<thinking>\nThe user has asked for the value of magic_function applied to the input 3. Looking at the available tools, magic_function is the relevant one to use here, as it takes an integer input and returns an integer output.\n\nThe magic_function has one required parameter:\n- input (integer)\n\nThe user has directly provided the value 3 for the input parameter. Since the required parameter is present, we can proceed with calling the function.\n</thinking>', 'type': 'text'}, {'id': 'toolu_01HsTheJPA5mcipuFDBbJ1CW', 'input': {'input': 3}, 'name': 'magic_function', 'type': 'tool_use'}] 5 Therefore, the value of magic_function(3) is 5. > Finished chain. {'input': 'what is the value of magic_function(3)?', 'output': 'Therefore, the value of magic_function(3) is 5.'} ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>pull/20290/head
parent
9eb6f538f0
commit
21c1ce0bc1
@ -1,59 +1,5 @@
|
|||||||
import json
|
from langchain.agents.format_scratchpad.tools import (
|
||||||
from typing import List, Sequence, Tuple
|
format_to_tool_messages as format_to_openai_tool_messages,
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction
|
|
||||||
from langchain_core.messages import (
|
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
|
||||||
ToolMessage,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
__all__ = ["format_to_openai_tool_messages"]
|
||||||
|
|
||||||
|
|
||||||
def _create_tool_message(
    agent_action: OpenAIToolAgentAction, observation: str
) -> ToolMessage:
    """Convert an agent action and its observation into a ToolMessage.

    Args:
        agent_action: the tool invocation request from the agent.
        observation: the result of the tool invocation.

    Returns:
        ToolMessage that corresponds to the original tool invocation.
    """
    # Tool outputs may be arbitrary Python objects; serialize them so they
    # can travel in the message content field.
    if not isinstance(observation, str):
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            # Fall back to str() for objects that are not JSON-serializable.
            content = str(observation)
    else:
        content = observation
    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=content,
        additional_kwargs={"name": agent_action.tool},
    )
|
|
||||||
|
|
||||||
|
|
||||||
def format_to_openai_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Convert (AgentAction, tool output) tuples into ToolMessages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with
            their observations.

    Returns:
        list of messages to send to the LLM for the next prediction.
    """
    messages: List[BaseMessage] = []
    for agent_action, observation in intermediate_steps:
        if isinstance(agent_action, OpenAIToolAgentAction):
            # Replay the AI message(s) that requested this tool call, then
            # append the tool result. Skip messages already collected so an
            # AI message that issued several tool calls appears only once.
            new_messages = list(agent_action.message_log) + [
                _create_tool_message(agent_action, observation)
            ]
            messages.extend([new for new in new_messages if new not in messages])
        else:
            # Non-tool actions carry their context in the free-form log.
            messages.append(AIMessage(content=agent_action.log))
    return messages
|
|
||||||
|
@ -0,0 +1,59 @@
|
|||||||
|
import json
|
||||||
|
from typing import List, Sequence, Tuple
|
||||||
|
|
||||||
|
from langchain_core.agents import AgentAction
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.agents.output_parsers.tools import ToolAgentAction
|
||||||
|
|
||||||
|
|
||||||
|
def _create_tool_message(
    agent_action: ToolAgentAction, observation: str
) -> ToolMessage:
    """Convert an agent action and its observation into a ToolMessage.

    Args:
        agent_action: the tool invocation request from the agent.
        observation: the result of the tool invocation.

    Returns:
        ToolMessage that corresponds to the original tool invocation.
    """
    # Tool outputs may be arbitrary Python objects; serialize them so they
    # can travel in the message content field.
    if not isinstance(observation, str):
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            # Fall back to str() for objects that are not JSON-serializable.
            content = str(observation)
    else:
        content = observation
    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=content,
        additional_kwargs={"name": agent_action.tool},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def format_to_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Convert (AgentAction, tool output) tuples into ToolMessages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with
            their observations.

    Returns:
        list of messages to send to the LLM for the next prediction.
    """
    messages: List[BaseMessage] = []
    for agent_action, observation in intermediate_steps:
        if isinstance(agent_action, ToolAgentAction):
            # Replay the AI message(s) that requested this tool call, then
            # append the tool result. Skip messages already collected so an
            # AI message that issued several tool calls appears only once.
            new_messages = list(agent_action.message_log) + [
                _create_tool_message(agent_action, observation)
            ]
            messages.extend([new for new in new_messages if new not in messages])
        else:
            # Non-tool actions carry their context in the free-form log.
            messages.append(AIMessage(content=agent_action.log))
    return messages
|
@ -0,0 +1,102 @@
|
|||||||
|
import json
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
|
|
||||||
|
from langchain.agents.agent import MultiActionAgentOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
class ToolAgentAction(AgentActionMessageLog):
    """AgentAction that additionally records the id of its originating tool call."""

    tool_call_id: str
    """Tool call that this message is responding to."""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ai_message_to_tool_action(
    message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
    """Parse an AI message potentially containing tool_calls.

    Args:
        message: the chat model output; must be an AIMessage.

    Returns:
        AgentFinish when the message carries no tool calls; otherwise a list
        with one ToolAgentAction per tool call.

    Raises:
        TypeError: if ``message`` is not an AIMessage.
        OutputParserException: if a raw tool call's ``arguments`` payload is
            not valid JSON.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    actions: List = []
    if message.tool_calls:
        tool_calls = message.tool_calls
    else:
        if not message.additional_kwargs.get("tool_calls"):
            # No tool calls anywhere: treat the content as the final answer.
            return AgentFinish(
                return_values={"output": message.content}, log=str(message.content)
            )
        # Best-effort parsing of provider-specific raw tool calls.
        tool_calls = []
        for tool_call in message.additional_kwargs["tool_calls"]:
            function = tool_call["function"]
            function_name = function["name"]
            try:
                args = json.loads(function["arguments"] or "{}")
                tool_calls.append(
                    ToolCall(name=function_name, args=args, id=tool_call["id"])
                )
            except JSONDecodeError as exc:
                # Chain the decode error so the original cause is preserved.
                raise OutputParserException(
                    f"Could not parse tool input: {function} because "
                    f"the `arguments` is not valid JSON."
                ) from exc
    for tool_call in tool_calls:
        # HACK HACK HACK:
        # The code that encodes tool input into Open AI uses a special variable
        # name called `__arg1` to handle old style tools that do not expose a
        # schema and expect a single string argument as an input.
        # We unpack the argument here if it exists.
        # Open AI does not support passing in a JSON array as an argument.
        function_name = tool_call["name"]
        _tool_input = tool_call["args"]
        if "__arg1" in _tool_input:
            tool_input = _tool_input["__arg1"]
        else:
            tool_input = _tool_input

        content_msg = f"responded: {message.content}\n" if message.content else "\n"
        log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
        actions.append(
            ToolAgentAction(
                tool=function_name,
                tool_input=tool_input,
                log=log,
                message_log=[message],
                tool_call_id=tool_call["id"],
            )
        )
    return actions
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsAgentOutputParser(MultiActionAgentOutputParser):
    """Parse a chat-model message into agent actions or a finish.

    When the message carries a ``tool_calls`` payload, the tool names and
    tool inputs are read from it. Otherwise the AIMessage content is taken
    to be the agent's final output.
    """

    @property
    def _type(self) -> str:
        return "tools-agent-output-parser"

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        generation = result[0]
        # Only chat generations carry a message to inspect for tool calls.
        if not isinstance(generation, ChatGeneration):
            raise ValueError("This output parser only works on ChatGeneration output")
        return parse_ai_message_to_tool_action(generation.message)

    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
        # Plain-text parsing is unsupported: tool calls live on messages.
        raise ValueError("Can only parse messages")
|
@ -0,0 +1,96 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from langchain.agents.format_scratchpad.tools import (
|
||||||
|
format_to_tool_messages,
|
||||||
|
)
|
||||||
|
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_calling_agent(
    llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
) -> Runnable:
    """Create an agent that uses tools.

    Args:
        llm: LLM to use as the agent.
        tools: Tools this agent has access to.
        prompt: The prompt to use. See Prompt section below for more on the expected
            input variables.

    Returns:
        A Runnable sequence representing an agent. It takes as input all the same input
        variables as the prompt passed in does. It returns as output either an
        AgentAction or AgentFinish.

    Example:

        .. code-block:: python

            from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
            from langchain_anthropic import ChatAnthropic
            from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are a helpful assistant"),
                    MessagesPlaceholder("chat_history", optional=True),
                    ("human", "{input}"),
                    MessagesPlaceholder("agent_scratchpad"),
                ]
            )
            model = ChatAnthropic(model="claude-3-opus-20240229")

            @tool
            def magic_function(input: int) -> int:
                \"\"\"Applies a magic function to an input.\"\"\"
                return input + 2

            tools = [magic_function]

            agent = create_tool_calling_agent(model, tools, prompt)
            agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

            agent_executor.invoke({"input": "what is the value of magic_function(3)?"})

            # Using with chat history
            from langchain_core.messages import AIMessage, HumanMessage
            agent_executor.invoke(
                {
                    "input": "what's my name?",
                    "chat_history": [
                        HumanMessage(content="hi! my name is bob"),
                        AIMessage(content="Hello Bob! How can I assist you today?"),
                    ],
                }
            )

    Prompt:

        The agent prompt must have an `agent_scratchpad` key that is a
        ``MessagesPlaceholder``. Intermediate agent actions and tool output
        messages will be passed in here.
    """
    # The scratchpad placeholder is mandatory: that is where intermediate
    # tool-call messages are injected on every agent step.
    missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
    if missing_vars:
        raise ValueError(f"Prompt missing required variables: {missing_vars}")

    if not hasattr(llm, "bind_tools"):
        raise ValueError(
            "This function requires a .bind_tools method be implemented on the LLM.",
        )
    llm_with_tools = llm.bind_tools(tools)

    # Pipeline: fill the scratchpad from intermediate steps, render the
    # prompt, call the tool-bound model, then parse out actions/finish.
    fill_scratchpad = RunnablePassthrough.assign(
        agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
    )
    return fill_scratchpad | prompt | llm_with_tools | ToolsAgentOutputParser()
|
Loading…
Reference in New Issue