diff --git a/docs/examples/agents/mrkl.ipynb b/docs/examples/agents/mrkl.ipynb index 7f885bab..8b54347b 100644 --- a/docs/examples/agents/mrkl.ipynb +++ b/docs/examples/agents/mrkl.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 2, "id": "07e96d99", "metadata": {}, "outputs": [], @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 3, "id": "a069c4b6", "metadata": {}, "outputs": [], @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "id": "e603cd7d", "metadata": {}, "outputs": [ @@ -121,7 +121,7 @@ "\"Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\"" ] }, - "execution_count": 29, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "id": "a5c07010", "metadata": {}, "outputs": [ @@ -170,7 +170,7 @@ "\"Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\"" ] }, - "execution_count": 25, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -182,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f13b1c3", + "id": "af016a70", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 6efd1db9..59319a85 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Tuple -from pydantic import BaseModel +from pydantic import BaseModel, root_validator from langchain.agents.input import ChainedInput from langchain.agents.tools import Tool @@ -41,6 +41,16 @@ class Agent(Chain, BaseModel, ABC): """ return [self.output_key] + @root_validator() + def validate_prompt(cls, values: Dict) -> Dict: + """Validate that prompt matches format.""" + prompt = values["llm_chain"].prompt + if "agent_scratchpad" not in prompt.input_variables: + raise ValueError( + "`agent_scratchpad` should be a variable in prompt.input_variables" + ) + return values + @property @abstractmethod def observation_prefix(self) -> str: diff --git a/langchain/agents/react/textworld_prompt.py b/langchain/agents/react/textworld_prompt.py index b6675501..b832a6bb 100644 --- a/langchain/agents/react/textworld_prompt.py +++ b/langchain/agents/react/textworld_prompt.py @@ -47,4 +47,6 @@ Action 4: Finish[yes] SUFFIX = """\n\nSetup: {input} {agent_scratchpad}""" -TEXTWORLD_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) +TEXTWORLD_PROMPT = PromptTemplate.from_examples( + EXAMPLES, SUFFIX, ["input", "agent_scratchpad"] +) diff --git a/langchain/agents/react/wiki_prompt.py b/langchain/agents/react/wiki_prompt.py index 10db44d4..24370406 100644 --- a/langchain/agents/react/wiki_prompt.py +++ b/langchain/agents/react/wiki_prompt.py @@ -110,4 +110,6 @@ Action 3: Finish[yes]""", SUFFIX = """\n\nQuestion: {input} {agent_scratchpad}""" -WIKI_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) +WIKI_PROMPT = PromptTemplate.from_examples( + EXAMPLES, SUFFIX, ["input", "agent_scratchpad"] +) diff --git a/langchain/agents/self_ask_with_search/prompt.py b/langchain/agents/self_ask_with_search/prompt.py index 8d33ddbf..e511a64b 100644 --- a/langchain/agents/self_ask_with_search/prompt.py +++ b/langchain/agents/self_ask_with_search/prompt.py @@ -39,4 +39,6 @@ So the final answer is: No Question: {input} {agent_scratchpad}""" -PROMPT = PromptTemplate(input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE) +PROMPT = PromptTemplate( + input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE +) diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 2d6c4cca..16cc26ab 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -56,7 +56,7 @@ def test_predict_until_observation_normal() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("") + output = agent.get_action("", {"input": ""}) assert output.log == outputs[0] assert output.tool == "Search" assert output.tool_input == "foo" @@ -71,7 +71,7 @@ def test_predict_until_observation_repeat() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("") + output = agent.get_action("", {"input": ""}) assert output.log == "foo\nAction 1: Search[foo]" assert output.tool == "Search" assert output.tool_input == "foo"