mirror of https://github.com/hwchase17/langchain
{
"cells": [
{
"cell_type": "markdown",
"id": "ba5f8741",
"metadata": {},
"source": [
"# Custom agent with tool retrieval\n",
"\n",
"The novel idea introduced in this notebook is using retrieval to select the set of tools used to answer an agent query. This is useful when you have many tools to select from. You cannot put the descriptions of all the tools in the prompt (because of context length limits), so instead you dynamically select the N tools you want to consider at run time.\n",
"\n",
"In this notebook we will create a somewhat contrived example. We will have one legitimate tool (search) and 99 fake tools that are just nonsense. We will then add a step in the prompt template that takes the user input and retrieves the tools relevant to the query."
]
},
{
"cell_type": "markdown",
"id": "fea4812c",
"metadata": {},
"source": [
"## Set up environment\n",
"\n",
"Do necessary imports, etc."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9af9734e",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from typing import Union\n",
"\n",
"from langchain.agents import (\n",
"    AgentExecutor,\n",
"    AgentOutputParser,\n",
"    LLMSingleActionAgent,\n",
"    Tool,\n",
")\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.prompts import StringPromptTemplate\n",
"from langchain.schema import AgentAction, AgentFinish\n",
"from langchain.utilities import SerpAPIWrapper"
]
},
{
"cell_type": "markdown",
"id": "6df0253f",
"metadata": {},
"source": [
"## Set up tools\n",
"\n",
"We will create one legitimate tool (search) and then 99 fake tools."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "becda2a1",
"metadata": {},
"outputs": [],
"source": [
"# Define which tools the agent can use to answer user queries\n",
"search = SerpAPIWrapper()\n",
"search_tool = Tool(\n",
"    name=\"Search\",\n",
"    func=search.run,\n",
"    description=\"useful for when you need to answer questions about current events\",\n",
")\n",
"\n",
"\n",
"def fake_func(inp: str) -> str:\n",
"    return \"foo\"\n",
"\n",
"\n",
"fake_tools = [\n",
"    Tool(\n",
"        name=f\"foo-{i}\",\n",
"        func=fake_func,\n",
"        description=f\"a silly function that you can use to get more information about the number {i}\",\n",
"    )\n",
"    for i in range(99)\n",
"]\n",
"ALL_TOOLS = [search_tool] + fake_tools"
]
},
{
"cell_type": "markdown",
"id": "17362717",
"metadata": {},
"source": [
"## Tool Retriever\n",
"\n",
"We will embed each tool description and store the embeddings in a vector store. Then, for an incoming query, we can embed the query and run a similarity search for relevant tools."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "77c4be4b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema import Document\n",
"from langchain.vectorstores import FAISS"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9092a158",
"metadata": {},
"outputs": [],
"source": [
"docs = [\n",
"    Document(page_content=t.description, metadata={\"index\": i})\n",
"    for i, t in enumerate(ALL_TOOLS)\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "affc4e56",
"metadata": {},
"outputs": [],
"source": [
"vector_store = FAISS.from_documents(docs, OpenAIEmbeddings())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "735a7566",
"metadata": {},
"outputs": [],
"source": [
"retriever = vector_store.as_retriever()\n",
"\n",
"\n",
"def get_tools(query):\n",
"    # Retrieve the tool descriptions most similar to the query\n",
"    docs = retriever.get_relevant_documents(query)\n",
"    # Map each retrieved document back to its tool via the stored index\n",
"    return [ALL_TOOLS[d.metadata[\"index\"]] for d in docs]"
]
},
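The retrieval step can be illustrated without any API keys. The following is a minimal sketch, not the notebook's implementation: it swaps `OpenAIEmbeddings` and `FAISS` for a toy bag-of-words vector and a brute-force cosine ranking. The names `TOOLS`, `embed`, `cosine`, and `get_tool_names` are assumptions made for this illustration.

```python
import math
import re
from collections import Counter

# Hypothetical stand-in for ALL_TOOLS: plain (name, description) pairs.
TOOLS = [
    ("Search", "useful for when you need to answer questions about current events")
] + [
    (f"foo-{i}", f"a silly function that you can use to get more information about the number {i}")
    for i in range(99)
]


def embed(text: str) -> Counter:
    """Toy 'embedding': a bag-of-words count vector (not a real embedding model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Index every tool description once, which is the role the vector store plays.
INDEX = [embed(description) for _, description in TOOLS]


def get_tool_names(query: str, k: int = 4) -> list:
    """Return the names of the k tools whose descriptions best match the query."""
    q = embed(query)
    ranked = sorted(range(len(TOOLS)), key=lambda i: cosine(q, INDEX[i]), reverse=True)
    return [TOOLS[i][0] for i in ranked[:k]]


print(get_tool_names("what is the number 13?"))
```

Because this toy similarity is purely lexical, it only matches shared words. A real embedding model is what lets a query such as "whats the weather?" land on the search tool even though the description never mentions weather.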
{
"cell_type": "markdown",
"id": "7699afd7",
"metadata": {},
"source": [
"We can now test the retriever to check that it works."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "425f2886",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Tool(name='Search', description='useful for when you need to answer questions about current events', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<bound method SerpAPIWrapper.run of SerpAPIWrapper(search_engine=<class 'serpapi.google_search.GoogleSearch'>, params={'engine': 'google', 'google_domain': 'google.com', 'gl': 'us', 'hl': 'en'}, serpapi_api_key='', aiosession=None)>, coroutine=None),\n",
" Tool(name='foo-95', description='a silly function that you can use to get more information about the number 95', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),\n",
" Tool(name='foo-12', description='a silly function that you can use to get more information about the number 12', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),\n",
" Tool(name='foo-15', description='a silly function that you can use to get more information about the number 15', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_tools(\"whats the weather?\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "4036dd19",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Tool(name='foo-13', description='a silly function that you can use to get more information about the number 13', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),\n",
" Tool(name='foo-12', description='a silly function that you can use to get more information about the number 12', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),\n",
" Tool(name='foo-14', description='a silly function that you can use to get more information about the number 14', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),\n",
" Tool(name='foo-11', description='a silly function that you can use to get more information about the number 11', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_tools(\"whats the number 13?\")"
]
},
{
"cell_type": "markdown",
"id": "2e7a075c",
"metadata": {},
"source": [
"## Prompt template\n",
"\n",
"The prompt template is fairly standard: we are not changing much of the logic in the prompt template itself, only how the tools are retrieved."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "339b1bb8",
"metadata": {},
"outputs": [],
"source": [
"# Set up the base template\n",
"template = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\n",
"\n",
"{tools}\n",
"\n",
"Use the following format:\n",
"\n",
"Question: the input question you must answer\n",
"Thought: you should always think about what to do\n",
"Action: the action to take, should be one of [{tool_names}]\n",
"Action Input: the input to the action\n",
"Observation: the result of the action\n",
"... (this Thought/Action/Action Input/Observation can repeat N times)\n",
"Thought: I now know the final answer\n",
"Final Answer: the final answer to the original input question\n",
"\n",
"Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s\n",
"\n",
"Question: {input}\n",
"{agent_scratchpad}\"\"\""
]
},
{
"cell_type": "markdown",
"id": "1583acdc",
"metadata": {},
"source": [
"The custom prompt template now has the concept of a `tools_getter`, which we call on the input to select the tools to use."
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "fd969d31",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"\n",
"# Set up a prompt template\n",
"class CustomPromptTemplate(StringPromptTemplate):\n",
"    # The template to use\n",
"    template: str\n",
"    ############## NEW ######################\n",
"    # Callable that takes the user input and returns the tools to use\n",
"    tools_getter: Callable\n",
"\n",
"    def format(self, **kwargs) -> str:\n",
"        # Get the intermediate steps (AgentAction, Observation tuples)\n",
"        # Format them in a particular way\n",
"        intermediate_steps = kwargs.pop(\"intermediate_steps\")\n",
"        thoughts = \"\"\n",
"        for action, observation in intermediate_steps:\n",
"            thoughts += action.log\n",
"            thoughts += f\"\\nObservation: {observation}\\nThought: \"\n",
"        # Set the agent_scratchpad variable to that value\n",
"        kwargs[\"agent_scratchpad\"] = thoughts\n",
"        ############## NEW ######################\n",
"        tools = self.tools_getter(kwargs[\"input\"])\n",
"        # Create a tools variable from the list of tools provided\n",
"        kwargs[\"tools\"] = \"\\n\".join(\n",
"            [f\"{tool.name}: {tool.description}\" for tool in tools]\n",
"        )\n",
"        # Create a list of tool names for the tools provided\n",
"        kwargs[\"tool_names\"] = \", \".join([tool.name for tool in tools])\n",
"        return self.template.format(**kwargs)"
]
},
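To see what `format` produces without running an LLM, the same steps can be traced in plain Python. This is a sketch under stated assumptions: `Tool` and `Action` here are hypothetical namedtuple stand-ins for the LangChain classes, `TEMPLATE` is a shortened template, and `fake_getter` is an invented retriever that always returns one tool.

```python
from collections import namedtuple

# Hypothetical stand-ins for LangChain's Tool and AgentAction, for illustration only.
Tool = namedtuple("Tool", ["name", "description"])
Action = namedtuple("Action", ["log"])

# Shortened template with the same four variables as the notebook's template.
TEMPLATE = "Tools:\n{tools}\nAllowed: [{tool_names}]\nQuestion: {input}\n{agent_scratchpad}"


def format_prompt(template, tools_getter, **kwargs):
    # Rebuild the scratchpad from the earlier (action, observation) steps.
    steps = kwargs.pop("intermediate_steps")
    thoughts = ""
    for action, observation in steps:
        thoughts += action.log
        thoughts += f"\nObservation: {observation}\nThought: "
    kwargs["agent_scratchpad"] = thoughts
    # Select tools for this specific input, then render them into the prompt.
    tools = tools_getter(kwargs["input"])
    kwargs["tools"] = "\n".join(f"{t.name}: {t.description}" for t in tools)
    kwargs["tool_names"] = ", ".join(t.name for t in tools)
    return template.format(**kwargs)


def fake_getter(query):
    # Hypothetical retriever: ignores the query and returns a single tool.
    return [Tool("Search", "answers questions about current events")]


rendered = format_prompt(
    TEMPLATE,
    fake_getter,
    input="What's the weather in SF?",
    intermediate_steps=[
        (Action("Thought: check\nAction: Search\nAction Input: Weather in SF"), "Cloudy")
    ],
)
print(rendered)
```

The point of the design is that the tool list is computed inside `format`, from the user input, every time the prompt is rendered, rather than being fixed when the template is constructed.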
{
"cell_type": "code",
"execution_count": 53,
"id": "798ef9fb",
"metadata": {},
"outputs": [],
"source": [
"prompt = CustomPromptTemplate(\n",
"    template=template,\n",
"    tools_getter=get_tools,\n",
"    # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n",
"    # This includes the `intermediate_steps` variable because that is needed\n",
"    input_variables=[\"input\", \"intermediate_steps\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ef3a1af3",
"metadata": {},
"source": [
"## Output parser\n",
"\n",
"The output parser is unchanged from the previous notebook, since we are not changing anything about the output format."
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "7c6fe0d3",
"metadata": {},
"outputs": [],
"source": [
"class CustomOutputParser(AgentOutputParser):\n",
"    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
"        # Check if agent should finish\n",
"        if \"Final Answer:\" in llm_output:\n",
"            return AgentFinish(\n",
"                # Return values is generally always a dictionary with a single `output` key\n",
"                # It is not recommended to try anything else at the moment :)\n",
"                return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
"                log=llm_output,\n",
"            )\n",
"        # Parse out the action and action input\n",
"        regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
"        match = re.search(regex, llm_output, re.DOTALL)\n",
"        if not match:\n",
"            raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
"        action = match.group(1).strip()\n",
"        action_input = match.group(2)\n",
"        # Return the action and action input\n",
"        return AgentAction(\n",
"            tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output\n",
"        )"
]
},
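The parsing logic can be exercised on sample text without LangChain installed. The sketch below reuses the regular expression from `CustomOutputParser.parse` but returns plain tuples instead of `AgentAction` and `AgentFinish`; the function name `parse_step` and the tuple shapes are assumptions made for this illustration.

```python
import re

# Same pattern as in CustomOutputParser.parse.
ACTION_RE = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"


def parse_step(llm_output: str):
    """Return ("finish", answer) or ("action", tool, tool_input)."""
    # A final answer takes priority over any action in the text.
    if "Final Answer:" in llm_output:
        return ("finish", llm_output.split("Final Answer:")[-1].strip())
    match = re.search(ACTION_RE, llm_output, re.DOTALL)
    if not match:
        raise ValueError(f"Could not parse LLM output: `{llm_output}`")
    tool = match.group(1).strip()
    # Strip surrounding spaces and double quotes from the tool input.
    tool_input = match.group(2).strip(" ").strip('"')
    return ("action", tool, tool_input)


step = parse_step("Thought: I need the weather\nAction: Search\nAction Input: Weather in SF")
final = parse_step("Thought: I now know the final answer\nFinal Answer: Arg, 'tis cloudy")
print(step)
print(final)
```

Text that contains neither a `Final Answer:` marker nor an `Action:` / `Action Input:` pair raises `ValueError`, which is how the parser signals output it cannot interpret.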
{
"cell_type": "code",
"execution_count": 55,
"id": "d278706a",
"metadata": {},
"outputs": [],
"source": [
"output_parser = CustomOutputParser()"
]
},
{
"cell_type": "markdown",
"id": "170587b1",
"metadata": {},
"source": [
"## Set up LLM, stop sequence, and the agent\n",
"\n",
"Also the same as the previous notebook."
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "f9d4c374",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "9b1cc2a2",
"metadata": {},
"outputs": [],
"source": [
"# LLM chain consisting of the LLM and a prompt\n",
"llm_chain = LLMChain(llm=llm, prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "e4f5092f",
"metadata": {},
"outputs": [],
"source": [
"# Note: the allowed tools are fixed here from one sample query; only the prompt retrieves tools per query\n",
"tools = get_tools(\"whats the weather?\")\n",
"tool_names = [tool.name for tool in tools]\n",
"agent = LLMSingleActionAgent(\n",
"    llm_chain=llm_chain,\n",
"    output_parser=output_parser,\n",
"    stop=[\"\\nObservation:\"],\n",
"    allowed_tools=tool_names,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "aa8a5326",
"metadata": {},
"source": [
"## Use the Agent\n",
"\n",
"Now we can use it!"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "490604e9",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor.from_agent_and_tools(\n",
"    agent=agent, tools=tools, verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "653b1617",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find out what the weather is in SF\n",
"Action: Search\n",
"Action Input: Weather in SF\u001b[0m\n",
"\n",
"Observation:\u001b[36;1m\u001b[1;3mMostly cloudy skies early, then partly cloudy in the afternoon. High near 60F. ENE winds shifting to W at 10 to 15 mph. Humidity71%. UV Index6 of 10.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 'Arg, 'tis mostly cloudy skies early, then partly cloudy in the afternoon. High near 60F. ENE winds shiftin' to W at 10 to 15 mph. Humidity71%. UV Index6 of 10.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"'Arg, 'tis mostly cloudy skies early, then partly cloudy in the afternoon. High near 60F. ENE winds shiftin' to W at 10 to 15 mph. Humidity71%. UV Index6 of 10.\""
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"What's the weather in SF?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2481ee76",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
},
"vscode": {
"interpreter": {
"hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}