forked from Archives/langchain
389 lines
12 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba5f8741",
   "metadata": {},
   "source": [
    "# Custom LLM Agent\n",
    "\n",
    "This notebook goes through how to create your own custom LLM agent.\n",
    "\n",
    "An LLM agent consists of four parts:\n",
    "\n",
    "- PromptTemplate: The prompt template that instructs the language model on what to do\n",
    "- LLM: The language model that powers the agent\n",
    "- `stop` sequence: Instructs the LLM to stop generating as soon as this string is found\n",
    "- OutputParser: Determines how to parse the LLM output into an `AgentAction` or `AgentFinish` object\n",
    "\n",
    "\n",
    "The LLM agent is used in an `AgentExecutor`. This `AgentExecutor` can largely be thought of as a loop that:\n",
    "1. Passes user input and any previous steps to the agent (in this case, the LLM agent)\n",
    "2. If the agent returns an `AgentFinish`, returns that directly to the user\n",
    "3. If the agent returns an `AgentAction`, uses that to call a tool and get an `Observation`\n",
    "4. Repeats, passing the `AgentAction` and `Observation` back to the agent until an `AgentFinish` is emitted.\n",
    "\n",
    "`AgentAction` is a response that consists of `tool` and `tool_input`. `tool` refers to which tool to use, and `tool_input` refers to the input to that tool. `log` can also be provided as more context (that can be used for logging, tracing, etc.).\n",
    "\n",
    "`AgentFinish` is a response that contains the final message to be sent back to the user. This should be used to end an agent run."
   ]
  },
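  {
   "cell_type": "markdown",
   "id": "3f8a1c20",
   "metadata": {},
   "source": [
    "The loop described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not LangChain's actual `AgentExecutor` implementation: `ToyAction`, `ToyFinish`, `scripted_agent`, `run_loop`, and the `max_steps` cap are hypothetical names made up for this sketch, and the \"agent\" is a scripted function rather than an LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f8a1c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Callable, Dict, List, Tuple, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyAction:  # hypothetical stand-in for AgentAction\n",
    "    tool: str\n",
    "    tool_input: str\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyFinish:  # hypothetical stand-in for AgentFinish\n",
    "    output: str\n",
    "\n",
    "\n",
    "def scripted_agent(user_input: str, steps: List[Tuple[ToyAction, str]]) -> Union[ToyAction, ToyFinish]:\n",
    "    # Scripted behaviour instead of an LLM: call a tool once, then finish\n",
    "    if not steps:\n",
    "        return ToyAction(tool=\"Echo\", tool_input=user_input)\n",
    "    return ToyFinish(output=f\"Done: {steps[-1][1]}\")\n",
    "\n",
    "\n",
    "def run_loop(agent: Callable, tools: Dict[str, Callable[[str], str]], user_input: str, max_steps: int = 5) -> str:\n",
    "    steps: List[Tuple[ToyAction, str]] = []\n",
    "    for _ in range(max_steps):\n",
    "        # 1. Pass the user input and any previous steps to the agent\n",
    "        decision = agent(user_input, steps)\n",
    "        # 2. A finish is returned directly to the user\n",
    "        if isinstance(decision, ToyFinish):\n",
    "            return decision.output\n",
    "        # 3. An action is used to call a tool and get an observation\n",
    "        observation = tools[decision.tool](decision.tool_input)\n",
    "        # 4. Repeat with the (action, observation) pair added to the history\n",
    "        steps.append((decision, observation))\n",
    "    raise RuntimeError(\"Agent did not finish within max_steps\")\n",
    "\n",
    "\n",
    "print(run_loop(scripted_agent, {\"Echo\": lambda text: text.upper()}, \"hello\"))"
   ]
  },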
  {
   "cell_type": "markdown",
   "id": "fea4812c",
   "metadata": {},
   "source": [
    "## Set up environment\n",
    "\n",
    "Import the classes used throughout this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9af9734e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser\n",
    "from langchain.prompts import StringPromptTemplate\n",
    "from langchain import OpenAI, SerpAPIWrapper, LLMChain\n",
    "from typing import List, Union\n",
    "from langchain.schema import AgentAction, AgentFinish\n",
    "import re"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df0253f",
   "metadata": {},
   "source": [
    "## Set up tool\n",
    "\n",
    "Set up any tools the agent may want to use. You may also need to describe these tools in the prompt so that the agent knows when to use them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "becda2a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define which tools the agent can use to answer user queries\n",
    "search = SerpAPIWrapper()\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events\"\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e7a075c",
   "metadata": {},
   "source": [
    "## Prompt Template\n",
    "\n",
    "This instructs the agent on what to do. Generally, the template should incorporate:\n",
    "\n",
    "- `tools`: which tools the agent has access to, and how and when to call them.\n",
    "- `intermediate_steps`: tuples of previous (`AgentAction`, `Observation`) pairs. These are generally not passed directly to the model; instead, the prompt template formats them in a specific way.\n",
    "- `input`: generic user input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "339b1bb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the base template\n",
    "template = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\n",
    "\n",
    "{tools}\n",
    "\n",
    "Use the following format:\n",
    "\n",
    "Question: the input question you must answer\n",
    "Thought: you should always think about what to do\n",
    "Action: the action to take, should be one of [{tool_names}]\n",
    "Action Input: the input to the action\n",
    "Observation: the result of the action\n",
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
    "Thought: I now know the final answer\n",
    "Final Answer: the final answer to the original input question\n",
    "\n",
    "Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s\n",
    "\n",
    "Question: {input}\n",
    "{agent_scratchpad}\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fd969d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up a prompt template\n",
    "class CustomPromptTemplate(StringPromptTemplate):\n",
    "    # The template to use\n",
    "    template: str\n",
    "    # The list of tools available\n",
    "    tools: List[Tool]\n",
    "\n",
    "    def format(self, **kwargs) -> str:\n",
    "        # Get the intermediate steps (AgentAction, Observation tuples)\n",
    "        # Format them in a particular way\n",
    "        intermediate_steps = kwargs.pop(\"intermediate_steps\")\n",
    "        thoughts = \"\"\n",
    "        for action, observation in intermediate_steps:\n",
    "            thoughts += action.log\n",
    "            thoughts += f\"\\nObservation: {observation}\\nThought: \"\n",
    "        # Set the agent_scratchpad variable to that value\n",
    "        kwargs[\"agent_scratchpad\"] = thoughts\n",
    "        # Create a tools variable from the list of tools provided\n",
    "        kwargs[\"tools\"] = \"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in self.tools])\n",
    "        # Create a list of tool names for the tools provided\n",
    "        kwargs[\"tool_names\"] = \", \".join([tool.name for tool in self.tools])\n",
    "        return self.template.format(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "798ef9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = CustomPromptTemplate(\n",
    "    template=template,\n",
    "    tools=tools,\n",
    "    # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n",
    "    # This includes the `intermediate_steps` variable because that is needed\n",
    "    input_variables=[\"input\", \"intermediate_steps\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef3a1af3",
   "metadata": {},
   "source": [
    "## Output Parser\n",
    "\n",
    "The output parser is responsible for parsing the LLM output into an `AgentAction` or an `AgentFinish`. This usually depends heavily on the prompt used.\n",
    "\n",
    "This is where you can change the parsing to do retries, handle whitespace, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7c6fe0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomOutputParser(AgentOutputParser):\n",
    "\n",
    "    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
    "        # Check if agent should finish\n",
    "        if \"Final Answer:\" in llm_output:\n",
    "            return AgentFinish(\n",
    "                # Return values is generally always a dictionary with a single `output` key\n",
    "                # It is not recommended to try anything else at the moment :)\n",
    "                return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
    "                log=llm_output,\n",
    "            )\n",
    "        # Parse out the action and action input\n",
    "        regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n",
    "        match = re.search(regex, llm_output, re.DOTALL)\n",
    "        if not match:\n",
    "            raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
    "        action = match.group(1).strip()\n",
    "        action_input = match.group(2)\n",
    "        # Return the action and action input\n",
    "        return AgentAction(tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output)"
   ]
  },
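  {
   "cell_type": "markdown",
   "id": "5d2b9e40",
   "metadata": {},
   "source": [
    "To see what the regular expression in the parser captures, it can be run on a sample completion with only the standard library. The `sample` text below is made up for illustration and is not real model output; the expression and the stripping steps are copied from `CustomOutputParser.parse`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2b9e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Made-up completion in the format requested by the prompt (an assumption for this sketch)\n",
    "sample = 'Thought: I should look this up\\nAction: Search\\nAction Input: \"population of Canada\"'\n",
    "\n",
    "regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n",
    "match = re.search(regex, sample, re.DOTALL)\n",
    "\n",
    "# Group 1 is the tool name, group 2 is the raw tool input\n",
    "tool = match.group(1).strip()\n",
    "tool_input = match.group(2).strip(\" \").strip('\"')\n",
    "print(tool)\n",
    "print(tool_input)"
   ]
  },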
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d278706a",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_parser = CustomOutputParser()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "170587b1",
   "metadata": {},
   "source": [
    "## Set up LLM\n",
    "\n",
    "Choose the LLM you want to use!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f9d4c374",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caeab5e4",
   "metadata": {},
   "source": [
    "## Define the stop sequence\n",
    "\n",
    "The stop sequence is important because it tells the LLM when to stop generating.\n",
    "\n",
    "The right value depends heavily on the prompt and model you are using. Generally, you want it to be whatever token you use in the prompt to denote the start of an `Observation` (otherwise, the LLM may hallucinate an observation for you)."
   ]
  },
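  {
   "cell_type": "markdown",
   "id": "7a4c0f10",
   "metadata": {},
   "source": [
    "The effect of a stop sequence can be illustrated without calling a model. The sketch below cuts a made-up completion at the first occurrence of the stop string, which is roughly the behaviour expected from an LLM API that is given `stop`. `truncate_at_stop` is a hypothetical helper written for this illustration and is not part of LangChain; the `raw` text is invented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4c0f11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_at_stop(text: str, stop: str) -> str:\n",
    "    # Keep everything before the first occurrence of the stop string\n",
    "    return text.split(stop)[0]\n",
    "\n",
    "\n",
    "# Made-up completion in which the model invents its own observation\n",
    "raw = 'Action: Search\\nAction Input: Canada population\\nObservation: 1 billion\\nFinal Answer: 1 billion'\n",
    "\n",
    "# With the stop string, only the action and its input remain\n",
    "print(truncate_at_stop(raw, '\\nObservation:'))"
   ]
  },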
  {
   "cell_type": "markdown",
   "id": "34be9f65",
   "metadata": {},
   "source": [
    "## Set up the Agent\n",
    "\n",
    "We can now combine everything to set up our agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9b1cc2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM chain consisting of the LLM and a prompt\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e4f5092f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_names = [tool.name for tool in tools]\n",
    "agent = LLMSingleActionAgent(\n",
    "    llm_chain=llm_chain,\n",
    "    output_parser=output_parser,\n",
    "    stop=[\"\\nObservation:\"],\n",
    "    allowed_tools=tool_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa8a5326",
   "metadata": {},
   "source": [
    "## Use the Agent\n",
    "\n",
    "Now we can use it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "490604e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "653b1617",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: Search\n",
      "Action Input: Population of Canada in 2023\u001b[0m\n",
      "\n",
      "Observation:\u001b[36;1m\u001b[1;3m38,648,380\u001b[0m\u001b[32;1m\u001b[1;3m That's a lot of people!\n",
      "Final Answer: Arrr, there be 38,648,380 people livin' in Canada come 2023!\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"Arrr, there be 38,648,380 people livin' in Canada come 2023!\""
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\"How many people live in canada as of 2023?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adefb4c2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}