From 440083fb3594c1e079587bbfd972abca02d7d17a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 8 Dec 2022 09:40:27 -0800 Subject: [PATCH 1/3] agent multi inputs --- langchain/agents/agent.py | 28 +++++++++++----------------- langchain/agents/mrkl/base.py | 2 +- langchain/agents/mrkl/prompt.py | 2 +- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 62c4dcca..160c5b7d 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -37,7 +37,7 @@ class Agent(Chain, BaseModel, ABC): :meta private: """ - return [self.input_key] + return set(self.llm_chain.input_keys) - {"thoughts"} @property def output_keys(self) -> List[str]: @@ -99,23 +99,24 @@ class Agent(Chain, BaseModel, ABC): llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) return cls(llm_chain=llm_chain, tools=tools, **kwargs) - def get_action(self, text: str) -> Action: + def get_action(self, thoughts: str, inputs: dict) -> Action: """Given input, decided what to do. Args: - text: input string + thoughts: LLM thoughts + inputs: user inputs Returns: Action specifying what tool to use. """ - input_key = self.llm_chain.input_keys[0] - inputs = {input_key: text, "stop": self._stop} - full_output = self.llm_chain.predict(**inputs) + new_inputs = {"thoughts": thoughts, "stop": self._stop} + full_inputs = {**inputs, **new_inputs} + full_output = self.llm_chain.predict(**full_inputs) parsed_output = self._extract_tool_and_input(full_output) while parsed_output is None: full_output = self._fix_text(full_output) - inputs = {input_key: text + full_output, "stop": self._stop} - output = self.llm_chain.predict(**inputs) + full_inputs["thoughts"] += full_output + output = self.llm_chain.predict(**full_inputs) full_output += output parsed_output = self._extract_tool_and_input(full_output) tool, tool_input = parsed_output @@ -123,19 +124,12 @@ class Agent(Chain, BaseModel, ABC): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: """Run text through and get agent response.""" - text = inputs[self.input_key] # Do any preparation necessary when receiving a new input. self._prepare_for_new_call() # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool.func for tool in self.tools} - # Construct the initial string to pass into the LLM. This is made up - # of the user input, the special starter string, and then the LLM prefix. - # The starter string is a special string that may be used by a LLM to - # immediately follow the user input. The LLM prefix is a string that - # prompts the LLM to take an action. - starter_string = text + self.starter_string + self.llm_prefix # We use the ChainedInput class to iteratively add to the input over time. - chained_input = ChainedInput(starter_string, verbose=self.verbose) + chained_input = ChainedInput(self.llm_prefix, verbose=self.verbose) # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( [tool.name for tool in self.tools], excluded_colors=["green"] @@ -143,7 +137,7 @@ class Agent(Chain, BaseModel, ABC): # We now enter the agent loop (until it returns something). while True: # Call the LLM to see what to do. - output = self.get_action(chained_input.input) + output = self.get_action(chained_input.input, inputs) # Add the log to the Chained Input. chained_input.add(output.log, color="green") # If the tool chosen is the finishing tool, then we end and return. diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 1519c38d..2f74dca3 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -85,7 +85,7 @@ class ZeroShotAgent(Agent): format_instructions = FORMAT_INSTRUCTIONS.format(tool_names=tool_names) template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) if input_variables is None: - input_variables = ["input"] + input_variables = ["input", "thoughts"] return PromptTemplate(template=template, input_variables=input_variables) @classmethod diff --git a/langchain/agents/mrkl/prompt.py b/langchain/agents/mrkl/prompt.py index c32310d4..3e957e8e 100644 --- a/langchain/agents/mrkl/prompt.py +++ b/langchain/agents/mrkl/prompt.py @@ -12,4 +12,4 @@ Thought: I now know the final answer Final Answer: the final answer to the original input question""" SUFFIX = """Begin! -Question: {input}""" +Question: {input}{thoughts}""" From aacd41707602067a4bdb9c120d1498f08dfdf619 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 10 Dec 2022 23:47:55 -0800 Subject: [PATCH 2/3] cr --- docs/examples/agents/mrkl.ipynb | 75 ++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/docs/examples/agents/mrkl.ipynb b/docs/examples/agents/mrkl.ipynb index d3571993..7f885bab 100644 --- a/docs/examples/agents/mrkl.ipynb +++ b/docs/examples/agents/mrkl.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 21, "id": "07e96d99", "metadata": {}, "outputs": [], @@ -46,7 +46,7 @@ " Tool(\n", " name = \"Search\",\n", " func=search.run,\n", - " description=\"useful for when you need to answer questions about current events\"\n", + " description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n", " ),\n", " Tool(\n", " name=\"Calculator\",\n", @@ -56,14 +56,14 @@ " Tool(\n", " name=\"FooBar DB\",\n", " func=db_chain.run,\n", - " description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question\"\n", + " description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n", " )\n", "]" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 22, "id": "a069c4b6", "metadata": {}, "outputs": [], @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 29, "id": "e603cd7d", "metadata": {}, "outputs": [ @@ -84,40 +84,55 @@ "\n", "\n", "\u001b[1m> Entering new ZeroShotAgent chain...\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to use a calculator to solve this.\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Who is Olivia Wilde's boyfriend?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"How old is Harry Styles?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 raised to the 0.23 power.\n", "Action: Calculator\n", - "Action Input: (age of Olivia Wilde's boyfriend)^0.23\u001b[0m\n", + "Action Input: 28^0.23\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "(age of Olivia Wilde's boyfriend)^0.23\u001b[32;1m\u001b[1;3m\n", + "28^0.23\u001b[32;1m\u001b[1;3m\n", "\n", - "Answer: 2.8379908717450045\u001b[0m\n", + "```python\n", + "import math\n", + "print(math.pow(28, 0.23))\n", + "```\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n", + "\u001b[0m\n", "\u001b[1m> Finished LLMMathChain chain.\u001b[0m\n", "\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.8379908717450045\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 2.8379908717450045\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n", + "\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n", "\u001b[1m> Finished ZeroShotAgent chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'2.8379908717450045'" + "\"Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\"" ] }, - "execution_count": 8, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "mrkl.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")" + "mrkl.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 25, "id": "a5c07010", "metadata": {}, "outputs": [ @@ -128,46 +143,46 @@ "\n", "\n", "\u001b[1m> Entering new ZeroShotAgent chain...\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out who released the album and if they are in the FooBar database.\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the artist's full name and then search the FooBar database for their albums.\n", "Action: Search\n", - "Action Input: \"The Storm Before the Calm\" album\u001b[0m\n", + "Action Input: \"The Storm Before the Calm\" artist\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to check if Alanis is in the FooBar database\n", + "Thought:\u001b[32;1m\u001b[1;3m I now need to search the FooBar database for Alanis Morissette's albums\n", "Action: FooBar DB\n", - "Action Input: \"Alanis\"\u001b[0m\n", + "Action Input: What albums by Alanis Morissette are in the FooBar database?\u001b[0m\n", "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", - "Alanis \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT FirstName, LastName FROM Customer WHERE FirstName LIKE 'Alanis'\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m No results found.\u001b[0m\n", + "What albums by Alanis Morissette are in the FooBar database? \n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Title FROM Album INNER JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette';\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m The album 'Jagged Little Pill' by Alanis Morissette is in the FooBar database.\u001b[0m\n", "\u001b[1m> Finished SQLDatabaseChain chain.\u001b[0m\n", "\n", - "Observation: \u001b[38;5;200m\u001b[1;3m No results found.\u001b[0m\n", + "Observation: \u001b[38;5;200m\u001b[1;3m The album 'Jagged Little Pill' by Alanis Morissette is in the FooBar database.\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Alanis released the album The Storm Before the Calm, but she is not in the FooBar database.\u001b[0m\n", + "Final Answer: Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\u001b[0m\n", "\u001b[1m> Finished ZeroShotAgent chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'Alanis released the album The Storm Before the Calm, but she is not in the FooBar database.'" + "\"Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\"" ] }, - "execution_count": 9, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "mrkl.run(\"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")" + "mrkl.run(\"What is the full name of the artist who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "d7c2e6ac", + "id": "3f13b1c3", "metadata": {}, "outputs": [], "source": [] From 81383474c4a4f1b0ef9bb1f7d286519cc9610a6f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 10 Dec 2022 23:58:02 -0800 Subject: [PATCH 3/3] cr --- docs/examples/agents/mrkl.ipynb | 14 +++++++------- langchain/agents/agent.py | 12 +++++++++++- langchain/agents/react/textworld_prompt.py | 4 +++- langchain/agents/react/wiki_prompt.py | 4 +++- langchain/agents/self_ask_with_search/prompt.py | 4 +++- tests/unit_tests/agents/test_react.py | 4 ++-- 6 files changed, 29 insertions(+), 13 deletions(-) diff --git a/docs/examples/agents/mrkl.ipynb b/docs/examples/agents/mrkl.ipynb index 7f885bab..8b54347b 100644 --- a/docs/examples/agents/mrkl.ipynb +++ b/docs/examples/agents/mrkl.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 2, "id": "07e96d99", "metadata": {}, "outputs": [], @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 3, "id": "a069c4b6", "metadata": {}, "outputs": [], @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "id": "e603cd7d", "metadata": {}, "outputs": [ @@ -121,7 +121,7 @@ "\"Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\"" ] }, - "execution_count": 29, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "id": "a5c07010", "metadata": {}, "outputs": [ @@ -170,7 +170,7 @@ "\"Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\"" ] }, - "execution_count": 25, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -182,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f13b1c3", + "id": "af016a70", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 6efd1db9..59319a85 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Tuple -from pydantic import BaseModel +from pydantic import BaseModel, root_validator from langchain.agents.input import ChainedInput from langchain.agents.tools import Tool @@ -41,6 +41,16 @@ class Agent(Chain, BaseModel, ABC): """ return [self.output_key] + @root_validator() + def validate_prompt(cls, values: Dict) -> Dict: + """Validate that prompt matches format.""" + prompt = values["llm_chain"].prompt + if "agent_scratchpad" not in prompt.input_variables: + raise ValueError( + "`agent_scratchpad` should be a variable in prompt.input_variables" + ) + return values + @property @abstractmethod def observation_prefix(self) -> str: diff --git a/langchain/agents/react/textworld_prompt.py b/langchain/agents/react/textworld_prompt.py index b6675501..b832a6bb 100644 --- a/langchain/agents/react/textworld_prompt.py +++ b/langchain/agents/react/textworld_prompt.py @@ -47,4 +47,6 @@ Action 4: Finish[yes] SUFFIX = """\n\nSetup: {input} {agent_scratchpad}""" -TEXTWORLD_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) +TEXTWORLD_PROMPT = PromptTemplate.from_examples( + EXAMPLES, SUFFIX, ["input", "agent_scratchpad"] +) diff --git a/langchain/agents/react/wiki_prompt.py b/langchain/agents/react/wiki_prompt.py index 10db44d4..24370406 100644 --- a/langchain/agents/react/wiki_prompt.py +++ b/langchain/agents/react/wiki_prompt.py @@ -110,4 +110,6 @@ Action 3: Finish[yes]""", SUFFIX = """\n\nQuestion: {input} {agent_scratchpad}""" -WIKI_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) +WIKI_PROMPT = PromptTemplate.from_examples( + EXAMPLES, SUFFIX, ["input", "agent_scratchpad"] +) diff --git a/langchain/agents/self_ask_with_search/prompt.py b/langchain/agents/self_ask_with_search/prompt.py index 8d33ddbf..e511a64b 100644 --- a/langchain/agents/self_ask_with_search/prompt.py +++ b/langchain/agents/self_ask_with_search/prompt.py @@ -39,4 +39,6 @@ So the final answer is: No Question: {input} {agent_scratchpad}""" -PROMPT = PromptTemplate(input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE) +PROMPT = PromptTemplate( + input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE +) diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 2d6c4cca..16cc26ab 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -56,7 +56,7 @@ def test_predict_until_observation_normal() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("") + output = agent.get_action("", {"input": ""}) assert output.log == outputs[0] assert output.tool == "Search" assert output.tool_input == "foo" @@ -71,7 +71,7 @@ def test_predict_until_observation_repeat() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("") + output = agent.get_action("", {"input": ""}) assert output.log == "foo\nAction 1: Search[foo]" assert output.tool == "Search" assert output.tool_input == "foo"