diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index ea9a3f46..50bf62ff 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -19,13 +19,183 @@ from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import AgentAction, AgentFinish, BaseMessage
+from langchain.schema import AgentAction, AgentFinish, BaseMessage, BaseOutputParser
 from langchain.tools.base import BaseTool
 
 logger = logging.getLogger()
 
 
-class Agent(BaseModel):
+class BaseSingleActionAgent(BaseModel):
+    """Base Agent class."""
+
+    @property
+    def return_values(self) -> List[str]:
+        """Return values of the agent."""
+        return ["output"]
+
+    def get_allowed_tools(self) -> Optional[List[str]]:
+        return None
+
+    @abstractmethod
+    def plan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+
+    @abstractmethod
+    async def aplan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+
+    @property
+    @abstractmethod
+    def input_keys(self) -> List[str]:
+        """Return the input keys.
+
+        :meta private:
+        """
+
+    def return_stopped_response(
+        self,
+        early_stopping_method: str,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        **kwargs: Any,
+    ) -> AgentFinish:
+        """Return response when agent has been stopped due to max iterations."""
+        if early_stopping_method == "force":
+            # `force` just returns a constant string
+            return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
+        else:
+            raise ValueError(
+                f"Got unsupported early_stopping_method `{early_stopping_method}`"
+            )
+
+    @property
+    def _agent_type(self) -> str:
+        """Return identifier of agent type."""
+        raise NotImplementedError
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        _dict["_type"] = self._agent_type
+        return _dict
+
+    def save(self, file_path: Union[Path, str]) -> None:
+        """Save the agent.
+
+        Args:
+            file_path: Path to file to save the agent to.
+
+        Example:
+        .. code-block:: python
+
+            # If working with agent executor
+            agent.agent.save(file_path="path/agent.yaml")
+        """
+        # Convert file to Path object.
+        if isinstance(file_path, str):
+            save_path = Path(file_path)
+        else:
+            save_path = file_path
+
+        directory_path = save_path.parent
+        directory_path.mkdir(parents=True, exist_ok=True)
+
+        # Fetch dictionary to save
+        agent_dict = self.dict()
+
+        if save_path.suffix == ".json":
+            with open(file_path, "w") as f:
+                json.dump(agent_dict, f, indent=4)
+        elif save_path.suffix == ".yaml":
+            with open(file_path, "w") as f:
+                yaml.dump(agent_dict, f, default_flow_style=False)
+        else:
+            raise ValueError(f"{save_path} must be json or yaml")
+
+    def tool_run_logging_kwargs(self) -> Dict:
+        return {}
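Note for reviewers: the new base class only requires `plan`/`aplan` and `input_keys`, so agents no longer have to be LLM-driven. A minimal sketch of a concrete subclass under the new interface (the `FixedToolAgent` name and its hard-coded tool are hypothetical, for illustration only):

```python
from typing import Any, List, Tuple, Union

from langchain.agents.agent import BaseSingleActionAgent
from langchain.schema import AgentAction, AgentFinish


class FixedToolAgent(BaseSingleActionAgent):
    """Hypothetical agent: call one fixed tool, then return its observation."""

    tool_name: str = "search"

    @property
    def input_keys(self) -> List[str]:
        return ["input"]

    def plan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        if intermediate_steps:
            # One tool call is enough; surface the observation as the answer.
            return AgentFinish({"output": intermediate_steps[-1][1]}, log="")
        return AgentAction(self.tool_name, kwargs["input"], log="")

    async def aplan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        return self.plan(intermediate_steps, **kwargs)
```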
+
+
+class AgentOutputParser(BaseOutputParser):
+    @abstractmethod
+    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+        """Parse text into agent action/finish."""
+
+
+class LLMSingleActionAgent(BaseSingleActionAgent):
+    llm_chain: LLMChain
+    output_parser: AgentOutputParser
+    stop: List[str]
+
+    @property
+    def input_keys(self) -> List[str]:
+        return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
+
+    def plan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        output = self.llm_chain.run(
+            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
+        )
+        return self.output_parser.parse(output)
+
+    async def aplan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        output = await self.llm_chain.arun(
+            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
+        )
+        return self.output_parser.parse(output)
+
+    def tool_run_logging_kwargs(self) -> Dict:
+        return {
+            "llm_prefix": "",
+            "observation_prefix": "" if len(self.stop) == 0 else self.stop[0],
+        }
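A sketch of how the `output_parser` hook is meant to be wired up, assuming a ReAct-style completion format (the marker strings and `llm_chain` below are illustrative, not part of this diff):

```python
import re
from typing import Union

from langchain.agents.agent import AgentOutputParser, LLMSingleActionAgent
from langchain.schema import AgentAction, AgentFinish


class ReActStyleParser(AgentOutputParser):
    """Hypothetical parser for 'Action:/Action Input:' style completions."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        if "Final Answer:" in text:
            answer = text.split("Final Answer:")[-1].strip()
            return AgentFinish({"output": answer}, log=text)
        match = re.search(r"Action:(.*?)\nAction Input:(.*)", text, re.DOTALL)
        if match is None:
            raise ValueError(f"Could not parse LLM output: {text}")
        return AgentAction(match.group(1).strip(), match.group(2).strip(), text)


# llm_chain: an LLMChain whose prompt accepts `intermediate_steps`
# (assumed to be constructed elsewhere).
agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=ReActStyleParser(),
    stop=["\nObservation:"],
)
```

Note that `stop` doubles as the observation prefix via `tool_run_logging_kwargs`, which keeps the printed trace aligned with what the LLM actually sees.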
+
+
+class Agent(BaseSingleActionAgent):
     """Class responsible for calling the language model and deciding the action.
 
     This is driven by an LLMChain. The prompt in the LLMChain MUST include
@@ -35,7 +205,13 @@ class Agent(BaseModel):
 
     llm_chain: LLMChain
     allowed_tools: Optional[List[str]] = None
-    return_values: List[str] = ["output"]
+
+    def get_allowed_tools(self) -> Optional[List[str]]:
+        return self.allowed_tools
+
+    @property
+    def return_values(self) -> List[str]:
+        return ["output"]
 
     @abstractmethod
     def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
@@ -248,55 +424,17 @@ class Agent(BaseModel):
                 f"got {early_stopping_method}"
             )
 
-    @property
-    @abstractmethod
-    def _agent_type(self) -> str:
-        """Return Identifier of agent type."""
-
-    def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of agent."""
-        _dict = super().dict()
-        _dict["_type"] = self._agent_type
-        return _dict
-
-    def save(self, file_path: Union[Path, str]) -> None:
-        """Save the agent.
-
-        Args:
-            file_path: Path to file to save the agent to.
-
-        Example:
-        .. code-block:: python
-
-            # If working with agent executor
-            agent.agent.save(file_path="path/agent.yaml")
-        """
-        # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
-
-        directory_path = save_path.parent
-        directory_path.mkdir(parents=True, exist_ok=True)
-
-        # Fetch dictionary to save
-        agent_dict = self.dict()
-
-        if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
-                json.dump(agent_dict, f, indent=4)
-        elif save_path.suffix == ".yaml":
-            with open(file_path, "w") as f:
-                yaml.dump(agent_dict, f, default_flow_style=False)
-        else:
-            raise ValueError(f"{save_path} must be json or yaml")
+    def tool_run_logging_kwargs(self) -> Dict:
+        return {
+            "llm_prefix": self.llm_prefix,
+            "observation_prefix": self.observation_prefix,
+        }
 
 
 class AgentExecutor(Chain, BaseModel):
     """Consists of an agent using tools."""
 
-    agent: Agent
+    agent: BaseSingleActionAgent
     tools: Sequence[BaseTool]
     return_intermediate_steps: bool = False
     max_iterations: Optional[int] = 15
@@ -305,7 +443,7 @@ class AgentExecutor(Chain, BaseModel):
     @classmethod
     def from_agent_and_tools(
         cls,
-        agent: Agent,
+        agent: BaseSingleActionAgent,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         **kwargs: Any,
@@ -320,10 +458,11 @@ class AgentExecutor(Chain, BaseModel):
         """Validate that tools are compatible with agent."""
         agent = values["agent"]
         tools = values["tools"]
-        if agent.allowed_tools is not None:
-            if set(agent.allowed_tools) != set([tool.name for tool in tools]):
+        allowed_tools = agent.get_allowed_tools()
+        if allowed_tools is not None:
+            if set(allowed_tools) != set([tool.name for tool in tools]):
                 raise ValueError(
-                    f"Allowed tools ({agent.allowed_tools}) different than "
+                    f"Allowed tools ({allowed_tools}) different than "
                     f"provided tools ({[tool.name for tool in tools]})"
                 )
         return values
@@ -418,22 +557,17 @@ class AgentExecutor(Chain, BaseModel):
             tool = name_to_tool_map[output.tool]
             return_direct = tool.return_direct
             color = color_mapping[output.tool]
-            llm_prefix = "" if return_direct else self.agent.llm_prefix
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            if return_direct:
+                tool_run_kwargs["llm_prefix"] = ""
             # We then call the tool on the tool input to get an observation
             observation = tool.run(
-                output.tool_input,
-                verbose=self.verbose,
-                color=color,
-                llm_prefix=llm_prefix,
-                observation_prefix=self.agent.observation_prefix,
+                output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
             )
         else:
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
             observation = InvalidTool().run(
-                output.tool,
-                verbose=self.verbose,
-                color=None,
-                llm_prefix="",
-                observation_prefix=self.agent.observation_prefix,
+                output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
             )
         return output, observation
 
@@ -467,22 +601,17 @@ class AgentExecutor(Chain, BaseModel):
             tool = name_to_tool_map[output.tool]
             return_direct = tool.return_direct
             color = color_mapping[output.tool]
-            llm_prefix = "" if return_direct else self.agent.llm_prefix
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            if return_direct:
+                tool_run_kwargs["llm_prefix"] = ""
             # We then call the tool on the tool input to get an observation
             observation = await tool.arun(
-                output.tool_input,
-                verbose=self.verbose,
-                color=color,
-                llm_prefix=llm_prefix,
-                observation_prefix=self.agent.observation_prefix,
+                output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
             )
         else:
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
            observation = await InvalidTool().arun(
-                output.tool,
-                verbose=self.verbose,
-                color=None,
-                llm_prefix="",
-                observation_prefix=self.agent.observation_prefix,
+                output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
             )
             return_direct = False
         return output, observation
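With prefixes now flowing through `tool_run_logging_kwargs()`, an agent that defines no prefixes (for example an `LLMSingleActionAgent` with an empty `stop` list) hands empty strings to the callbacks, and the `stdout.py` hunk below guards against printing them. A quick check of that behavior; the output string and the exact keyword order of `on_tool_end` are assumptions here:

```python
from langchain.callbacks.stdout import StdOutCallbackHandler

handler = StdOutCallbackHandler()

# Old behavior: printed a stray "\n" before and after the output even when
# both prefixes were empty. New behavior: empty prefixes are skipped, so
# only the observation itself is printed.
handler.on_tool_end("sunny, 22C", color=None, observation_prefix="", llm_prefix="")
```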
diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py
index 70867367..cabd8f0c 100644
--- a/langchain/callbacks/stdout.py
+++ b/langchain/callbacks/stdout.py
@@ -74,9 +74,11 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         **kwargs: Any,
     ) -> None:
         """If not the final action, print out observation."""
-        print_text(f"\n{observation_prefix}")
+        if observation_prefix:
+            print_text(f"\n{observation_prefix}")
         print_text(output, color=color if color else self.color)
-        print_text(f"\n{llm_prefix}")
+        if llm_prefix:
+            print_text(f"\n{llm_prefix}")
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index a89c5c4c..1b1837a3 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -199,7 +199,7 @@ class Chain(BaseModel, ABC):
         """Call the chain on all inputs in the list."""
         return [self(inputs) for inputs in input_list]
 
-    def run(self, *args: str, **kwargs: str) -> str:
+    def run(self, *args: Any, **kwargs: Any) -> str:
         """Run the chain as text in, text out or multiple variables, text out."""
         if len(self.output_keys) != 1:
             raise ValueError(
@@ -220,7 +220,7 @@ class Chain(BaseModel, ABC):
             f" but not both. Got args: {args} and kwargs: {kwargs}."
         )
 
-    async def arun(self, *args: str, **kwargs: str) -> str:
+    async def arun(self, *args: Any, **kwargs: Any) -> str:
         """Run the chain as text in, text out or multiple variables, text out."""
         if len(self.output_keys) != 1:
             raise ValueError(
diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py
index 166ac4d5..2038f301 100644
--- a/langchain/prompts/base.py
+++ b/langchain/prompts/base.py
@@ -144,9 +144,9 @@ class BasePromptTemplate(BaseModel, ABC):
         """
 
     @property
-    @abstractmethod
     def _prompt_type(self) -> str:
         """Return the prompt type key."""
+        raise NotImplementedError
 
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of prompt."""
diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py
index 7db3894d..f928efa7 100644
--- a/tests/unit_tests/agents/test_agent.py
+++ b/tests/unit_tests/agents/test_agent.py
@@ -149,7 +149,7 @@ def test_agent_with_callbacks_local() -> None:
         callback_manager=manager,
     )
 
-    agent.agent.llm_chain.verbose = True
+    agent.agent.llm_chain.verbose = True  # type: ignore
 
     output = agent.run("when was langchain made")
     assert output == "curses foiled again"
@@ -285,8 +285,8 @@ def test_agent_with_new_prefix_suffix() -> None:
     )
 
     # avoids "BasePromptTemplate" has no attribute "template" error
-    assert hasattr(agent.agent.llm_chain.prompt, "template")
-    prompt_str = agent.agent.llm_chain.prompt.template
+    assert hasattr(agent.agent.llm_chain.prompt, "template")  # type: ignore
+    prompt_str = agent.agent.llm_chain.prompt.template  # type: ignore
     assert prompt_str.startswith(prefix), "Prompt does not start with prefix"
     assert prompt_str.endswith(suffix), "Prompt does not end with suffix"
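Context for the `chains/base.py` hunk: `LLMSingleActionAgent.plan` above pushes a list of tuples (`intermediate_steps`) and a `List[str]` (`stop`) through `run`, which the old `*args: str, **kwargs: str` annotation rejected under mypy. A small sketch of the now-legal call pattern; the prompt and chain here are illustrative, not taken from this diff:

```python
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.prompts.prompt import PromptTemplate

prompt = PromptTemplate(
    input_variables=["question", "max_words"],
    template="Answer in at most {max_words} words: {question}",
)
chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

# `max_words` is an int and `stop` is a list -- both type-check now that
# run() accepts Any, while the return value is still a single string.
answer = chain.run(question="What is LangChain?", max_words=12, stop=["\n\n"])
```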