From fe572a5a0d306c1eb000998e282dc763d2878f43 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sun, 2 Apr 2023 14:04:09 -0700
Subject: [PATCH] chat model example (#2310)

---
 .../agents/agents/custom_llm_chat_agent.ipynb | 395 ++++++++++++++++++
 langchain/prompts/__init__.py                 |   2 +
 langchain/prompts/chat.py                     |  19 +-
 3 files changed, 413 insertions(+), 3 deletions(-)
 create mode 100644 docs/modules/agents/agents/custom_llm_chat_agent.ipynb

diff --git a/docs/modules/agents/agents/custom_llm_chat_agent.ipynb b/docs/modules/agents/agents/custom_llm_chat_agent.ipynb
new file mode 100644
index 00000000..bbac7bff
--- /dev/null
+++ b/docs/modules/agents/agents/custom_llm_chat_agent.ipynb
@@ -0,0 +1,395 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "ba5f8741",
+ "metadata": {},
+ "source": [
+ "# Custom LLM Agent (with a ChatModel)\n",
+ "\n",
+ "This notebook goes through how to create your own custom agent based on a chat model.\n",
+ "\n",
+ "An LLM chat agent consists of four parts:\n",
+ "\n",
+ "- PromptTemplate: This is the prompt template that can be used to instruct the language model on what to do\n",
+ "- ChatModel: This is the language model that powers the agent\n",
+ "- `stop` sequence: This instructs the LLM to stop generating as soon as this string is found\n",
+ "- OutputParser: This determines how to parse the LLM output into an AgentAction or AgentFinish object\n",
+ "\n",
+ "\n",
+ "The LLMAgent is used in an AgentExecutor. This AgentExecutor can largely be thought of as a loop that:\n",
+ "1. Passes user input and any previous steps to the Agent (in this case, the LLMAgent)\n",
+ "2. If the Agent returns an `AgentFinish`, returns that directly to the user\n",
+ "3. If the Agent returns an `AgentAction`, uses that to call a tool and get an `Observation`\n",
+ "4. Repeats, passing the `AgentAction` and `Observation` back to the Agent until an `AgentFinish` is emitted.\n",
+ "\n",
+ "`AgentAction` is a response that consists of `action` and `action_input`. `action` refers to which tool to use, and `action_input` refers to the input to that tool. A `log` can also be provided as more context (which can be used for logging, tracing, etc.).\n",
+ "\n",
+ "`AgentFinish` is a response that contains the final message to be sent back to the user. This should be used to end an agent run.\n",
+ "\n",
+ "In this notebook we walk through how to create a custom LLM agent."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fea4812c",
+ "metadata": {},
+ "source": [
+ "## Set up environment\n",
+ "\n",
+ "Do the necessary imports, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "9af9734e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser\n",
+ "from langchain.prompts import BaseChatPromptTemplate\n",
+ "from langchain import SerpAPIWrapper, LLMChain\n",
+ "from langchain.chat_models import ChatOpenAI\n",
+ "from typing import List, Union\n",
+ "from langchain.schema import AgentAction, AgentFinish, BaseMessage, HumanMessage\n",
+ "import re"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6df0253f",
+ "metadata": {},
+ "source": [
+ "## Set up tool\n",
+ "\n",
+ "Set up any tools the agent may want to use. These need to be described in the prompt (so that the agent knows it can use them).\n",
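+ "\n",
+ "Note that this example uses SerpAPI, which looks for your key in the `SERPAPI_API_KEY` environment variable. If you don't have one, any plain Python callable can stand in as a tool. A minimal sketch (the `get_word_length` helper here is purely illustrative, not part of the original example):\n",
+ "\n",
+ "```python\n",
+ "def get_word_length(word: str) -> str:\n",
+ "    # Toy tool: report the number of characters in a word\n",
+ "    return str(len(word.strip()))\n",
+ "\n",
+ "tools = [\n",
+ "    Tool(\n",
+ "        name=\"WordLength\",\n",
+ "        func=get_word_length,\n",
+ "        description=\"useful for when you need to count the letters in a word\"\n",
+ "    )\n",
+ "]\n",
+ "```"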
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "becda2a1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define which tools the agent can use to answer user queries\n",
+ "search = SerpAPIWrapper()\n",
+ "tools = [\n",
+ "    Tool(\n",
+ "        name=\"Search\",\n",
+ "        func=search.run,\n",
+ "        description=\"useful for when you need to answer questions about current events\"\n",
+ "    )\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e7a075c",
+ "metadata": {},
+ "source": [
+ "## Prompt Template\n",
+ "\n",
+ "This instructs the agent on what to do. Generally, the template should incorporate:\n",
+ "\n",
+ "- `tools`: which tools the agent has access to, and how and when to call them.\n",
+ "- `intermediate_steps`: tuples of previous (`AgentAction`, `Observation`) pairs. These are generally not passed directly to the model, but the prompt template formats them in a specific way.\n",
+ "- `input`: the generic user input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "339b1bb8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set up the base template\n",
+ "template = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\n",
+ "\n",
+ "{tools}\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [{tool_names}]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s\n",
Use lots of \"Arg\"s\n", + "\n", + "Question: {input}\n", + "{agent_scratchpad}\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "fd969d31", + "metadata": {}, + "outputs": [], + "source": [ + "# Set up a prompt template\n", + "class CustomPromptTemplate(BaseChatPromptTemplate):\n", + " # The template to use\n", + " template: str\n", + " # The list of tools available\n", + " tools: List[Tool]\n", + " \n", + " def format_messages(self, **kwargs) -> str:\n", + " # Get the intermediate steps (AgentAction, Observation tuples)\n", + " # Format them in a particular way\n", + " intermediate_steps = kwargs.pop(\"intermediate_steps\")\n", + " thoughts = \"\"\n", + " for action, observation in intermediate_steps:\n", + " thoughts += action.log\n", + " thoughts += f\"\\nObservation: {observation}\\nThought: \"\n", + " # Set the agent_scratchpad variable to that value\n", + " kwargs[\"agent_scratchpad\"] = thoughts\n", + " # Create a tools variable from the list of tools provided\n", + " kwargs[\"tools\"] = \"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in self.tools])\n", + " # Create a list of tool names for the tools provided\n", + " kwargs[\"tool_names\"] = \", \".join([tool.name for tool in self.tools])\n", + " formatted = self.template.format(**kwargs)\n", + " return [HumanMessage(content=formatted)]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "798ef9fb", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = CustomPromptTemplate(\n", + " template=template,\n", + " tools=tools,\n", + " # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n", + " # This includes the `intermediate_steps` variable because that is needed\n", + " input_variables=[\"input\", \"intermediate_steps\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ef3a1af3", + "metadata": {}, + "source": [ + "## Output Parser\n", + "\n", + "The output parser is responsible for parsing the LLM output into `AgentAction` and `AgentFinish`. 
+ "\n",
+ "This is where you can change the parsing to do retries, handle whitespace, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7c6fe0d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CustomOutputParser(AgentOutputParser):\n",
+ "\n",
+ "    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
+ "        # Check if agent should finish\n",
+ "        if \"Final Answer:\" in llm_output:\n",
+ "            return AgentFinish(\n",
+ "                # Return values is generally always a dictionary with a single `output` key\n",
+ "                # It is not recommended to try anything else at the moment :)\n",
+ "                return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
+ "                log=llm_output,\n",
+ "            )\n",
+ "        # Parse out the action and action input\n",
+ "        regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n",
+ "        match = re.search(regex, llm_output, re.DOTALL)\n",
+ "        if not match:\n",
+ "            raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
+ "        action = match.group(1).strip()\n",
+ "        action_input = match.group(2)\n",
+ "        # Return the action and action input\n",
+ "        return AgentAction(tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d278706a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output_parser = CustomOutputParser()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "170587b1",
+ "metadata": {},
+ "source": [
+ "## Set up LLM\n",
+ "\n",
+ "Choose the LLM you want to use!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "f9d4c374",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "llm = ChatOpenAI(temperature=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "caeab5e4",
+ "metadata": {},
+ "source": [
+ "## Define the stop sequence\n",
+ "\n",
+ "This is important because it tells the LLM when to stop generating.\n",
+ "\n",
+ "This depends heavily on the prompt and model you are using. Generally, you want this to be whatever token you use in the prompt to denote the start of an `Observation` (otherwise, the LLM may hallucinate an observation for you)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34be9f65",
+ "metadata": {},
+ "source": [
+ "## Set up the Agent\n",
+ "\n",
+ "We can now combine everything to set up our agent."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "9b1cc2a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# LLM chain consisting of the LLM and a prompt\n",
+ "llm_chain = LLMChain(llm=llm, prompt=prompt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "e4f5092f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tool_names = [tool.name for tool in tools]\n",
+ "agent = LLMSingleActionAgent(\n",
+ "    llm_chain=llm_chain,\n",
+ "    output_parser=output_parser,\n",
+ "    stop=[\"\\nObservation:\"],\n",
+ "    allowed_tools=tool_names\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aa8a5326",
+ "metadata": {},
+ "source": [
+ "## Use the Agent\n",
+ "\n",
+ "Now we can use it!\n",
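+ "\n",
+ "Under the hood, `AgentExecutor` runs the loop described at the top of this notebook. A rough sketch of that loop, simplified from the real implementation (here `user_input` stands for the user's query string):\n",
+ "\n",
+ "```python\n",
+ "# intermediate_steps collects (AgentAction, observation) pairs\n",
+ "intermediate_steps = []\n",
+ "while True:\n",
+ "    output = agent.plan(intermediate_steps, input=user_input)\n",
+ "    if isinstance(output, AgentFinish):\n",
+ "        break  # output.return_values[\"output\"] goes back to the user\n",
+ "    tool = {t.name: t for t in tools}[output.tool]\n",
+ "    observation = tool.func(output.tool_input)\n",
+ "    intermediate_steps.append((output, observation))\n",
+ "```"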
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "490604e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "653b1617",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+ "\u001b[32;1m\u001b[1;3mThought: Wot year be it now? That be important to know the answer.\n",
+ "Action: Search\n",
+ "Action Input: \"current population canada 2023\"\u001b[0m\n",
+ "\n",
+ "Observation:\u001b[36;1m\u001b[1;3m38,649,283\u001b[0m\u001b[32;1m\u001b[1;3mAhoy! That be the correct year, but the answer be in regular numbers. 'Tis time to translate to pirate speak.\n",
+ "Action: Search\n",
+ "Action Input: \"38,649,283 in pirate speak\"\u001b[0m\n",
+ "\n",
+ "Observation:\u001b[36;1m\u001b[1;3mBrush up on your “Pirate Talk” with these helpful pirate phrases. Aaaarrrrgggghhhh! Pirate catch phrase of grumbling or disgust. Ahoy! Hello! Ahoy, Matey, Hello ...\u001b[0m\u001b[32;1m\u001b[1;3mThat be not helpful, I'll just do the translation meself.\n",
+ "Final Answer: Arrrr, thar be 38,649,283 scallywags in Canada as of 2023.\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'Arrrr, thar be 38,649,283 scallywags in Canada as of 2023.'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "agent_executor.run(\"How many people live in canada as of 2023?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "adefb4c2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/prompts/__init__.py b/langchain/prompts/__init__.py
index aef564b6..2cd3aae1 100644
--- a/langchain/prompts/__init__.py
+++ b/langchain/prompts/__init__.py
@@ -2,6 +2,7 @@
 from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
 from langchain.prompts.chat import (
     AIMessagePromptTemplate,
+    BaseChatPromptTemplate,
     ChatMessagePromptTemplate,
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
@@ -27,4 +28,5 @@ __all__ = [
     "AIMessagePromptTemplate",
     "SystemMessagePromptTemplate",
     "ChatMessagePromptTemplate",
+    "BaseChatPromptTemplate",
 ]
diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py
index 87f67a6d..251b8e92 100644
--- a/langchain/prompts/chat.py
+++ b/langchain/prompts/chat.py
@@ -119,7 +119,20 @@ class ChatPromptValue(PromptValue):
         return self.messages
 
 
-class ChatPromptTemplate(BasePromptTemplate, ABC):
+class BaseChatPromptTemplate(BasePromptTemplate, ABC):
+    def format(self, **kwargs: Any) -> str:
+        return self.format_prompt(**kwargs).to_string()
+
+    def format_prompt(self, **kwargs: Any) -> PromptValue:
+        messages = self.format_messages(**kwargs)
+        return ChatPromptValue(messages=messages)
+
+    @abstractmethod
+    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+        """Format kwargs into a list of messages."""
+
+
+class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
     input_variables: List[str]
     messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
 
@@ -158,7 +171,7 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
     def format(self, **kwargs: Any) -> str:
         return self.format_prompt(**kwargs).to_string()
 
-    def format_prompt(self, **kwargs: Any) -> PromptValue:
+    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
         kwargs = self._merge_partial_and_user_variables(**kwargs)
         result = []
         for message_template in self.messages:
@@ -174,7 +187,7 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
                 result.extend(message)
             else:
                 raise ValueError(f"Unexpected input: {message_template}")
-        return ChatPromptValue(messages=result)
+        return result
 
     def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
         raise NotImplementedError