From edbf7045d2cebbfa99c950ade933873e0a72be95 Mon Sep 17 00:00:00 2001 From: blob42 Date: Wed, 17 May 2023 18:33:14 +0200 Subject: [PATCH] wip dynamic agent tools --- docs/modules/agents/tools/custom_tools.ipynb | 72 ++++++++++++++++---- langchain/agents/agent.py | 19 +++++- langchain/agents/conversational_chat/base.py | 26 ++++--- langchain/prompts/chat.py | 16 ++++- langchain/schema.py | 1 + 5 files changed, 105 insertions(+), 29 deletions(-) diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index a0fea125..6e0a3564 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "5436020b", "metadata": {}, @@ -56,7 +55,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "f8bc72c2", "metadata": {}, @@ -69,7 +67,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "b63fcc3b", "metadata": {}, @@ -111,7 +108,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "e9b560f7", "metadata": {}, @@ -145,7 +141,24 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, + "id": "5f040378", + "metadata": {}, + "outputs": [], + "source": [ + "tools = [\n", + " Tool(\n", + " name=\"FooBar\",\n", + " description=\"Useful to answer questions related to FooBar\",\n", + " func=lambda x: \"FooBar is well and alive !\",\n", + " return_direct=True\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "id": "5b93047d", "metadata": { "tags": [] @@ -154,7 +167,46 @@ "source": [ "# Construct the agent. We will use the default agent type here.\n", "# See documentation for a full list of options.\n", - "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" + "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_iterations=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2bb58cac", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI should use FooBar to answer this question\n", + "Action: FooBar\n", + "Action Input: \"Who is FooBar?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mFooBar is well and alive !\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'FooBar is well and alive !'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"Who is FooBar\")" ] }, { @@ -220,7 +272,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "6f12eaf0", "metadata": {}, @@ -459,7 +510,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "61d2e80b", "metadata": {}, @@ -470,7 +520,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "5be41722", "metadata": {}, @@ -499,7 +548,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "fb0a38eb", "metadata": {}, @@ -561,7 +609,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "7d68b0ac", "metadata": {}, @@ -589,7 +636,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "1d0430d6", "metadata": {}, @@ -857,7 +903,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.11" }, "vscode": { "interpreter": { diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index f73b5f26..27bb4074 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -619,7 +619,7 @@ class ExceptionTool(BaseTool): class AgentExecutor(Chain): """Consists of an agent using tools.""" - agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] + agent: Union[Agent, BaseSingleActionAgent, BaseMultiActionAgent] tools: Sequence[BaseTool] return_intermediate_steps: bool = False max_iterations: Optional[int] = 15 @@ -703,6 +703,23 @@ class AgentExecutor(Chain): """Lookup tool by name.""" return {tool.name: tool for tool in self.tools}[name] + def add_tools(self, tools: Sequence[BaseTool]) -> None: + """Add extra tools to an active agent instance.""" + self.agent._validate_tools(tools) + + #HACK: should not cast to list + self.tools = list(self.tools) + list(tools) + + # update allowed_tools + new_tool_names = [tool.name for tool in tools] + self.agent._allowed_tools = self.agent.get_allowed_tools().extend(new_tool_names) + + # how to update the agent prompt + # update the agent's llm_chain prompt + self.agent.llm_chain.prompt.update + + + def _should_continue(self, iterations: int, time_elapsed: float) -> bool: if self.max_iterations is not None and iterations >= self.max_iterations: return False diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index 17128543..16a6dc29 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -71,36 +71,34 @@ class ConversationalChatAgent(Agent): output_parser: Optional[BaseOutputParser] = None, ) -> BasePromptTemplate: tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] - ) + [f"> {tool.name}: {tool.description}" for tool in tools]) tool_names = ", ".join([tool.name for tool in tools]) _output_parser = output_parser or cls._get_default_output_parser() format_instructions = human_message.format( - format_instructions=_output_parser.get_format_instructions() - ) - final_prompt = format_instructions.format( - tool_names=tool_names, tools=tool_strings - ) + format_instructions=_output_parser.get_format_instructions()) + final_prompt = format_instructions.format(tool_names=tool_names, + tools=tool_strings) if input_variables is None: input_variables = ["input", "chat_history", "agent_scratchpad"] messages = [ SystemMessagePromptTemplate.from_template(system_message), MessagesPlaceholder(variable_name="chat_history"), - HumanMessagePromptTemplate.from_template(final_prompt), + HumanMessagePromptTemplate.from_template(final_prompt, + alias="instructions_tools"), MessagesPlaceholder(variable_name="agent_scratchpad"), ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) + return ChatPromptTemplate(input_variables=input_variables, + messages=messages) def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] - ) -> List[BaseMessage]: + self, intermediate_steps: List[Tuple[AgentAction, + str]]) -> List[BaseMessage]: """Construct the scratchpad that lets the agent continue its thought process.""" thoughts: List[BaseMessage] = [] for action, observation in intermediate_steps: thoughts.append(AIMessage(content=action.log)) - human_message = HumanMessage( - content=TEMPLATE_TOOL_RESPONSE.format(observation=observation) - ) + human_message = HumanMessage(content=TEMPLATE_TOOL_RESPONSE.format( + observation=observation)) thoughts.append(human_message) return thoughts diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 9df7cece..2a4627e4 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, Optional from pydantic import BaseModel, Field @@ -21,6 +21,8 @@ from langchain.schema import ( class BaseMessagePromptTemplate(BaseModel, ABC): + alias: Optional[str] = None + @abstractmethod def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """To messages.""" @@ -212,6 +214,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): raise ValueError(f"Unexpected input: {message_template}") return result + def update(self, msg_alias: str, + replacement_msg: Union[BaseMessagePromptTemplate, BaseMessage]) -> None: + for i, message in enumerate(self.messages): + if isinstance(message, BaseMessage) and message.alias == msg_alias: + if type(message) != type(replacement_msg): + raise ValueError("Replacement message is not of the same type as the original.") + self.messages[i] = replacement_msg + return + raise ValueError(f"No message with alias {msg_alias} found.") + + + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: raise NotImplementedError diff --git a/langchain/schema.py b/langchain/schema.py index 21552b9b..ae287d7c 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -68,6 +68,7 @@ class BaseMessage(BaseModel): """Message object.""" content: str + alias: Optional[str] = None additional_kwargs: dict = Field(default_factory=dict) @property