mirror of https://github.com/hwchase17/langchain
update agents to use tool call messages (#20074)
```python from langchain.agents import AgentExecutor, create_tool_calling_agent, tool from langchain_anthropic import ChatAnthropic from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder prompt = ChatPromptTemplate.from_messages( [ ("system", "You are a helpful assistant"), MessagesPlaceholder("chat_history", optional=True), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad"), ] ) model = ChatAnthropic(model="claude-3-opus-20240229") @tool def magic_function(input: int) -> int: """Applies a magic function to an input.""" return input + 2 tools = [magic_function] agent = create_tool_calling_agent(model, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor.invoke({"input": "what is the value of magic_function(3)?"}) ``` ``` > Entering new AgentExecutor chain... Invoking: `magic_function` with `{'input': 3}` responded: [{'text': '<thinking>\nThe user has asked for the value of magic_function applied to the input 3. Looking at the available tools, magic_function is the relevant one to use here, as it takes an integer input and returns an integer output.\n\nThe magic_function has one required parameter:\n- input (integer)\n\nThe user has directly provided the value 3 for the input parameter. Since the required parameter is present, we can proceed with calling the function.\n</thinking>', 'type': 'text'}, {'id': 'toolu_01HsTheJPA5mcipuFDBbJ1CW', 'input': {'input': 3}, 'name': 'magic_function', 'type': 'tool_use'}] 5 Therefore, the value of magic_function(3) is 5. > Finished chain. {'input': 'what is the value of magic_function(3)?', 'output': 'Therefore, the value of magic_function(3) is 5.'} ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>pull/20290/head
parent
9eb6f538f0
commit
21c1ce0bc1
@ -1,59 +1,5 @@
|
||||
import json
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
||||
|
||||
|
||||
def _create_tool_message(
|
||||
agent_action: OpenAIToolAgentAction, observation: str
|
||||
) -> ToolMessage:
|
||||
"""Convert agent action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tool invocation request from the agent
|
||||
observation: the result of the tool invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tool invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return ToolMessage(
|
||||
tool_call_id=agent_action.tool_call_id,
|
||||
content=content,
|
||||
additional_kwargs={"name": agent_action.tool},
|
||||
from langchain.agents.format_scratchpad.tools import (
|
||||
format_to_tool_messages as format_to_openai_tool_messages,
|
||||
)
|
||||
|
||||
|
||||
def format_to_openai_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Turn (AgentAction, tool output) pairs into a flat message history.

    Args:
        intermediate_steps: the actions taken so far, each paired with the
            observation its tool produced.

    Returns:
        The messages to hand back to the LLM for its next prediction.
    """
    history = []
    for step_action, step_output in intermediate_steps:
        if not isinstance(step_action, OpenAIToolAgentAction):
            # Plain actions carry no message log; replay their text log instead.
            history.append(AIMessage(content=step_action.log))
            continue
        candidates = [
            *step_action.message_log,
            _create_tool_message(step_action, step_output),
        ]
        # Membership is checked against the history as it stood before this
        # step, so several actions sharing one AI message add it only once.
        unseen = [msg for msg in candidates if msg not in history]
        history.extend(unseen)
    return history
|
||||
__all__ = ["format_to_openai_tool_messages"]
|
||||
|
@ -0,0 +1,59 @@
|
||||
import json
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from langchain.agents.output_parsers.tools import ToolAgentAction
|
||||
|
||||
|
||||
def _create_tool_message(
    agent_action: ToolAgentAction, observation: str
) -> ToolMessage:
    """Build the ToolMessage that answers a single tool invocation.

    Args:
        agent_action: the tool invocation request from the agent
        observation: the result of the tool invocation

    Returns:
        ToolMessage tied to the originating call through its ``tool_call_id``
    """
    if isinstance(observation, str):
        payload = observation
    else:
        # Non-string results are serialised as JSON when possible; anything
        # json cannot handle falls back to its plain string form.
        try:
            payload = json.dumps(observation, ensure_ascii=False)
        except Exception:
            payload = str(observation)

    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=payload,
        additional_kwargs={"name": agent_action.tool},
    )
|
||||
|
||||
|
||||
def format_to_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Convert (AgentAction, tool output) tuples into chat messages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with observations

    Returns:
        list of messages to send to the LLM for the next prediction
    """
    scratchpad = []
    for action, tool_output in intermediate_steps:
        if isinstance(action, ToolAgentAction):
            step_messages = list(action.message_log)
            step_messages.append(_create_tool_message(action, tool_output))
            # Filter against what was collected from earlier steps only, then
            # add the survivors in one go.
            additions = [msg for msg in step_messages if msg not in scratchpad]
            scratchpad.extend(additions)
        else:
            # Actions without a message log are replayed from their text log.
            scratchpad.append(AIMessage(content=action.log))
    return scratchpad
|
@ -0,0 +1,102 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
|
||||
from langchain.agents.agent import MultiActionAgentOutputParser
|
||||
|
||||
|
||||
class ToolAgentAction(AgentActionMessageLog):
    """Agent action that also records the id of the tool call it was parsed from."""

    # Copied from the tool call's ``id`` when the action is created.
    tool_call_id: str
    """Tool call that this message is responding to."""
|
||||
|
||||
|
||||
def parse_ai_message_to_tool_action(
    message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
    """Parse an AI message potentially containing tool_calls.

    Args:
        message: the model output to interpret; must be an ``AIMessage``.

    Returns:
        An ``AgentFinish`` wrapping ``message.content`` when the message has
        neither ``tool_calls`` nor ``additional_kwargs["tool_calls"]``;
        otherwise a list with one ``ToolAgentAction`` per tool call, in order.

    Raises:
        TypeError: if ``message`` is not an ``AIMessage``.
        OutputParserException: if the ``arguments`` of a raw tool call in
            ``additional_kwargs`` is not valid JSON.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    actions: List = []
    if message.tool_calls:
        tool_calls = message.tool_calls
    else:
        raw_tool_calls = message.additional_kwargs.get("tool_calls")
        if not raw_tool_calls:
            return AgentFinish(
                return_values={"output": message.content}, log=str(message.content)
            )
        # Best-effort parsing
        tool_calls = []
        for raw_call in raw_tool_calls:
            function = raw_call["function"]
            function_name = function["name"]
            # Only the JSON decoding can raise JSONDecodeError, so keep the
            # try body to that single call.
            try:
                args = json.loads(function["arguments"] or "{}")
            except JSONDecodeError as exc:
                # Chain explicitly so the decode error is reported as the cause.
                raise OutputParserException(
                    f"Could not parse tool input: {function} because "
                    f"the `arguments` is not valid JSON."
                ) from exc
            tool_calls.append(
                ToolCall(name=function_name, args=args, id=raw_call["id"])
            )

    # The message content is the same for every action, so render it once.
    content_msg = f"responded: {message.content}\n" if message.content else "\n"

    for tool_call in tool_calls:
        # HACK HACK HACK:
        # The code that encodes tool input into Open AI uses a special variable
        # name called `__arg1` to handle old style tools that do not expose a
        # schema and expect a single string argument as an input.
        # We unpack the argument here if it exists.
        # Open AI does not support passing in a JSON array as an argument.
        function_name = tool_call["name"]
        _tool_input = tool_call["args"]
        # Best-effort JSON parsing above may yield a non-dict (e.g. a bare
        # string); only look for the `__arg1` key when there is a mapping to
        # look in, instead of failing on the membership test or the lookup.
        if isinstance(_tool_input, dict) and "__arg1" in _tool_input:
            tool_input = _tool_input["__arg1"]
        else:
            tool_input = _tool_input

        log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
        actions.append(
            ToolAgentAction(
                tool=function_name,
                tool_input=tool_input,
                log=log,
                message_log=[message],
                tool_call_id=tool_call["id"],
            )
        )
    return actions
|
||||
|
||||
|
||||
class ToolsAgentOutputParser(MultiActionAgentOutputParser):
    """Output parser that maps a chat message to agent actions or a finish.

    When the message carries tool calls, those supply the tool names and
    tool inputs of the returned actions.

    When it carries none, the AIMessage is treated as the final output.
    """

    @property
    def _type(self) -> str:
        return "tools-agent-output-parser"

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        """Parse the first generation's message; only chat generations are accepted."""
        first_generation = result[0]
        if isinstance(first_generation, ChatGeneration):
            return parse_ai_message_to_tool_action(first_generation.message)
        raise ValueError("This output parser only works on ChatGeneration output")

    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
        """Always raises: this parser works on messages, not raw text."""
        raise ValueError("Can only parse messages")
|
@ -0,0 +1,96 @@
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain.agents.format_scratchpad.tools import (
|
||||
format_to_tool_messages,
|
||||
)
|
||||
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
||||
|
||||
|
||||
def create_tool_calling_agent(
    llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
) -> Runnable:
    """Build a tool-calling agent as a runnable pipeline.

    Args:
        llm: Model driving the agent; it must provide a ``bind_tools`` method.
        tools: Tools to bind to the model.
        prompt: Prompt for the agent. It has to declare an ``agent_scratchpad``
            input variable (see the Prompt section below).

    Returns:
        A Runnable sequence: the scratchpad is filled in from
        ``intermediate_steps``, then the prompt, the tool-bound model and a
        ``ToolsAgentOutputParser`` run in turn. Its inputs are the prompt's
        input variables; its output is a list of agent actions or an
        AgentFinish.

    Raises:
        ValueError: if the prompt lacks ``agent_scratchpad`` or the model has
            no ``bind_tools`` attribute.

    Example:

        .. code-block:: python

            from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
            from langchain_anthropic import ChatAnthropic
            from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are a helpful assistant"),
                    MessagesPlaceholder("chat_history", optional=True),
                    ("human", "{input}"),
                    MessagesPlaceholder("agent_scratchpad"),
                ]
            )
            model = ChatAnthropic(model="claude-3-opus-20240229")

            @tool
            def magic_function(input: int) -> int:
                \"\"\"Applies a magic function to an input.\"\"\"
                return input + 2

            tools = [magic_function]

            agent = create_tool_calling_agent(model, tools, prompt)
            agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

            agent_executor.invoke({"input": "what is the value of magic_function(3)?"})

            # Using with chat history
            from langchain_core.messages import AIMessage, HumanMessage
            agent_executor.invoke(
                {
                    "input": "what's my name?",
                    "chat_history": [
                        HumanMessage(content="hi! my name is bob"),
                        AIMessage(content="Hello Bob! How can I assist you today?"),
                    ],
                }
            )

    Prompt:

        The agent prompt must have an `agent_scratchpad` key that is a
        ``MessagesPlaceholder``. Intermediate agent actions and tool output
        messages will be passed in here.
    """
    required_vars = {"agent_scratchpad"}
    missing_vars = required_vars - set(prompt.input_variables)
    if missing_vars:
        raise ValueError(f"Prompt missing required variables: {missing_vars}")

    if not hasattr(llm, "bind_tools"):
        raise ValueError(
            "This function requires a .bind_tools method be implemented on the LLM.",
        )
    tool_bound_llm = llm.bind_tools(tools)

    # First stage: turn the executor's intermediate steps into scratchpad messages.
    fill_scratchpad = RunnablePassthrough.assign(
        agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
    )
    return fill_scratchpad | prompt | tool_bound_llm | ToolsAgentOutputParser()
|
Loading…
Reference in New Issue