add return_direct flag to tool (#537)

adds a return_direct flag to tools, which just returns the tool output
as the final output
pull/552/head^2
Harrison Chase 1 year ago committed by GitHub
parent 1f248c47f3
commit 4974f49bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "36ed392e",
"metadata": {},
"outputs": [],
@ -63,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "56ff7670",
"metadata": {},
"outputs": [],
@ -94,7 +94,6 @@
"source": [
"# Construct the agent. We will use the default agent type here.\n",
"# See documentation for a full list of options.\n",
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
]
},
@ -248,10 +247,86 @@
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
]
},
{
"cell_type": "markdown",
"id": "a8d95999",
"metadata": {},
"source": [
"## Using tools to return directly\n",
"\n",
"Often, it can be desirable to have a tool output returned directly to the user, if it's called. You can do this easily with LangChain by setting the `return_direct` flag for a tool to be True."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8f65640",
"metadata": {},
"outputs": [],
"source": [
"llm_math_chain = LLMMathChain(llm=llm)\n",
"tools = [\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\",\n",
" return_direct=True\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4bd1c4bb",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "52fe0594",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
"Action: Calculator\n",
"Action Input: 2**.12\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Answer: 1.2599210498948732'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"whats 2**.12\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3450512e",
"id": "7cc1b875",
"metadata": {},
"outputs": [],
"source": []
@ -259,7 +334,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.0 64-bit ('llm-env')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -273,7 +348,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.9"
},
"vscode": {
"interpreter": {

@ -239,12 +239,20 @@ class AgentExecutor(Chain, BaseModel):
else:
return iterations < self.max_iterations
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
if self.verbose:
self.callback_manager.on_agent_finish(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.
self.agent.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
@ -258,24 +266,20 @@ class AgentExecutor(Chain, BaseModel):
output = self.agent.plan(intermediate_steps, **inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
self.callback_manager.on_agent_finish(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
return self._return(output, intermediate_steps)
# And then we lookup the tool
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
tool = name_to_tool_map[output.tool]
if self.verbose:
self.callback_manager.on_tool_start(
{"name": str(chain)[:60] + "..."}, output, color="green"
{"name": str(tool.func)[:60] + "..."}, output, color="green"
)
try:
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
observation = tool.func(output.tool_input)
color = color_mapping[output.tool]
return_direct = tool.return_direct
except Exception as e:
if self.verbose:
self.callback_manager.on_tool_error(e)
@ -287,21 +291,22 @@ class AgentExecutor(Chain, BaseModel):
)
observation = f"{output.tool} is not a valid tool, try another one."
color = None
return_direct = False
if self.verbose:
llm_prefix = "" if return_direct else self.agent.llm_prefix
self.callback_manager.on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,
llm_prefix=self.agent.llm_prefix,
llm_prefix=llm_prefix,
)
intermediate_steps.append((output, observation))
if return_direct:
# Set the log to "" because we do not want to log it.
output = AgentFinish({self.agent.return_values[0]: observation}, "")
return self._return(output, intermediate_steps)
iterations += 1
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
if self.verbose:
self.callback_manager.on_agent_finish(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
return self._return(output, intermediate_steps)

@ -10,3 +10,4 @@ class Tool:
name: str
func: Callable[[str], str]
description: Optional[str] = None
return_direct: bool = False

@ -169,3 +169,24 @@ def test_agent_with_callbacks_not_verbose() -> None:
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0
def test_agent_tool_return_direct() -> None:
"""Test agent using tools that return directly."""
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses)
tools = [
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
)
output = agent.run("when was langchain made")
assert output == "misalignment"

Loading…
Cancel
Save