Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase a1fec47d9b track intermediate steps 1 year ago

@@ -13,6 +13,26 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "51cbc105",
"metadata": {},
"outputs": [],
"source": [
"from langchain import OpenAI, SerpAPIWrapper\n",
"from langchain.agents import initialize_agent, Tool\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name=\"Intermediate Answer\",\n",
" func=search.run\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7e3b513e",
"metadata": {},
"outputs": [
@@ -39,24 +59,12 @@
"'El Palmar, Spain'"
]
},
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain import OpenAI, SerpAPIWrapper\n",
"from langchain.agents import initialize_agent, Tool\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name=\"Intermediate Answer\",\n",
" func=search.run\n",
" )\n",
"]\n",
"\n",
"self_ask_with_search = initialize_agent(tools, llm, agent=\"self-ask-with-search\", verbose=True)\n",
"\n",
"self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")"
@@ -64,9 +72,51 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "683d69e7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SelfAskWithSearchAgent chain...\u001b[0m\n",
"What is the hometown of the reigning men's U.S. Open champion?\n",
"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
"Intermediate answer: \u001b[36;1m\u001b[1;3mCarlos Alcaraz\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mFollow up: Where is Carlos Alcaraz from?\u001b[0m\n",
"Intermediate answer: \u001b[36;1m\u001b[1;3mEl Palmar, Spain\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSo the final answer is: El Palmar, Spain\u001b[0m\n",
"\u001b[1m> Finished SelfAskWithSearchAgent chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"What is the hometown of the reigning men's U.S. Open champion?\",\n",
" 'output': 'El Palmar, Spain',\n",
" 'intermediate_steps': \"What is the hometown of the reigning men's U.S. Open champion?\\nAre follow up questions needed here: Yes.\\nFollow up: Who is the reigning men's U.S. Open champion?\\nIntermediate answer: Carlos Alcaraz\\nFollow up: Where is Carlos Alcaraz from?\\nIntermediate answer: El Palmar, Spain\\nSo the final answer is: El Palmar, Spain\"}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"self_ask_with_search = initialize_agent(tools, llm, agent=\"self-ask-with-search\", verbose=True, return_intermediate_steps=True)\n",
"\n",
"self_ask_with_search({\"input\": \"What is the hometown of the reigning men's U.S. Open champion?\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b30a1b9a",
"metadata": {},
"outputs": [],
"source": []
}
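
Read as plain Python rather than escaped notebook JSON, the updated notebook cells boil down to roughly the following (a sketch against the LangChain version at this commit; it assumes OpenAI and SerpAPI credentials are already configured in the environment):

from langchain import OpenAI, SerpAPIWrapper
from langchain.agents import initialize_agent, Tool

llm = OpenAI(temperature=0)
search = SerpAPIWrapper()
tools = [
    Tool(
        name="Intermediate Answer",
        func=search.run
    )
]

# New in this commit: return_intermediate_steps=True adds an
# "intermediate_steps" key alongside "output" in the result dict.
self_ask_with_search = initialize_agent(
    tools, llm, agent="self-ask-with-search", verbose=True, return_intermediate_steps=True
)

result = self_ask_with_search(
    {"input": "What is the hometown of the reigning men's U.S. Open champion?"}
)
print(result["output"])              # 'El Palmar, Spain'
print(result["intermediate_steps"])  # the full self-ask trace as a single string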

@@ -26,6 +26,8 @@ class Agent(Chain, BaseModel, ABC):
prompt: ClassVar[BasePromptTemplate]
llm_chain: LLMChain
tools: List[Tool]
+ return_intermediate_steps: bool = False
+ intermediate_steps_key: str = "intermediate_steps" #: :meta private:
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
@@ -43,7 +45,10 @@ class Agent(Chain, BaseModel, ABC):
:meta private:
"""
- return [self.output_key]
+ if self.return_intermediate_steps:
+     return [self.output_key, self.intermediate_steps_key]
+ else:
+     return [self.output_key]

@property
@abstractmethod
@@ -146,7 +151,10 @@ class Agent(Chain, BaseModel, ABC):
chained_input.add(output.log, color="green")
# If the tool chosen is the finishing tool, then we end and return.
if output.tool == self.finish_tool_name:
- return {self.output_key: output.tool_input}
+ final_output = {self.output_key: output.tool_input}
+ if self.return_intermediate_steps:
+     final_output[self.intermediate_steps_key] = chained_input.input
+ return final_output
# Otherwise we lookup the tool
chain = name_to_tool_map[output.tool]
# We then call the tool on the tool input to get an observation
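
The agent.py change is an instance of a common opt-in output pattern: a flag that both advertises and populates an extra output key. A minimal, self-contained sketch of that pattern (ToyAgent and its method names are hypothetical stand-ins, not LangChain APIs):

from typing import Dict, List


class ToyAgent:
    """Illustrates the opt-in extra-output-key pattern used by the Agent change."""

    def __init__(self, return_intermediate_steps: bool = False):
        self.return_intermediate_steps = return_intermediate_steps
        self.output_key = "output"
        self.intermediate_steps_key = "intermediate_steps"

    @property
    def output_keys(self) -> List[str]:
        # Mirrors the new Agent.output_keys: the extra key is only advertised
        # when the flag is on, so the default output schema stays stable.
        if self.return_intermediate_steps:
            return [self.output_key, self.intermediate_steps_key]
        return [self.output_key]

    def finish(self, answer: str, log: str) -> Dict[str, str]:
        # Mirrors the new finish branch: the accumulated reasoning log is
        # attached to the result only when the caller asked for it.
        final_output = {self.output_key: answer}
        if self.return_intermediate_steps:
            final_output[self.intermediate_steps_key] = log
        return final_output

Keeping the extra key behind a default-off flag means existing callers that only read "output" see no change in behavior.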
