forked from Archives/langchain
Harrison/return intermediate (#1633)
Co-authored-by: Mario Kostelac <mario@intercom.io>
This commit is contained in:
parent
72b461e257
commit
aed9f9febe
@ -435,10 +435,6 @@ class AgentExecutor(Chain, BaseModel):
|
||||
llm_prefix="",
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
)
|
||||
return_direct = False
|
||||
if return_direct:
|
||||
# Set the log to "" because we do not want to log it.
|
||||
return AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return output, observation
|
||||
|
||||
async def _atake_next_step(
|
||||
@ -483,9 +479,6 @@ class AgentExecutor(Chain, BaseModel):
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
)
|
||||
return_direct = False
|
||||
if return_direct:
|
||||
# Set the log to "" because we do not want to log it.
|
||||
return AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return output, observation
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
@ -510,6 +503,10 @@ class AgentExecutor(Chain, BaseModel):
|
||||
return self._return(next_step_output, intermediate_steps)
|
||||
|
||||
intermediate_steps.append(next_step_output)
|
||||
# See if tool should return directly
|
||||
tool_return = self._get_tool_return(next_step_output)
|
||||
if tool_return is not None:
|
||||
return self._return(tool_return, intermediate_steps)
|
||||
iterations += 1
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
@ -538,8 +535,28 @@ class AgentExecutor(Chain, BaseModel):
|
||||
return await self._areturn(next_step_output, intermediate_steps)
|
||||
|
||||
intermediate_steps.append(next_step_output)
|
||||
# See if tool should return directly
|
||||
tool_return = self._get_tool_return(next_step_output)
|
||||
if tool_return is not None:
|
||||
return await self._areturn(tool_return, intermediate_steps)
|
||||
|
||||
iterations += 1
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
|
||||
def _get_tool_return(
|
||||
self, next_step_output: Tuple[AgentAction, str]
|
||||
) -> Optional[AgentFinish]:
|
||||
"""Check if the tool is a returning tool."""
|
||||
agent_action, observation = next_step_output
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
if name_to_tool_map[agent_action.tool].return_direct:
|
||||
return AgentFinish(
|
||||
{self.agent.return_values[0]: observation},
|
||||
"",
|
||||
)
|
||||
return None
|
||||
|
@ -230,6 +230,36 @@ def test_agent_tool_return_direct() -> None:
|
||||
assert output == "misalignment"
|
||||
|
||||
|
||||
def test_agent_tool_return_direct_in_intermediate_steps() -> None:
|
||||
"""Test agent using tools that return directly."""
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
return_direct=True,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
resp = agent("when was langchain made")
|
||||
assert resp["output"] == "misalignment"
|
||||
assert len(resp["intermediate_steps"]) == 1
|
||||
action, _action_intput = resp["intermediate_steps"][0]
|
||||
assert action.tool == "Search"
|
||||
|
||||
|
||||
def test_agent_with_new_prefix_suffix() -> None:
|
||||
"""Test agent initilization kwargs with new prefix and suffix."""
|
||||
fake_llm = FakeListLLM(
|
||||
|
Loading…
Reference in New Issue
Block a user