wip dynamic agent tools

This commit is contained in:
blob42 2023-05-17 18:33:14 +02:00 committed by blob42
parent 8dcad0f272
commit edbf7045d2
5 changed files with 105 additions and 29 deletions

View File

@ -1,7 +1,6 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "5436020b", "id": "5436020b",
"metadata": {}, "metadata": {},
@ -56,7 +55,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "f8bc72c2", "id": "f8bc72c2",
"metadata": {}, "metadata": {},
@ -69,7 +67,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "b63fcc3b", "id": "b63fcc3b",
"metadata": {}, "metadata": {},
@ -111,7 +108,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "e9b560f7", "id": "e9b560f7",
"metadata": {}, "metadata": {},
@ -145,7 +141,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 7,
"id": "5f040378",
"metadata": {},
"outputs": [],
"source": [
"tools = [\n",
" Tool(\n",
" name=\"FooBar\",\n",
" description=\"Useful to answer questions related to FooBar\",\n",
" func=lambda x: \"FooBar is well and alive !\",\n",
" return_direct=True\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5b93047d", "id": "5b93047d",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -154,7 +167,46 @@
"source": [ "source": [
"# Construct the agent. We will use the default agent type here.\n", "# Construct the agent. We will use the default agent type here.\n",
"# See documentation for a full list of options.\n", "# See documentation for a full list of options.\n",
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_iterations=3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2bb58cac",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI should use FooBar to answer this question\n",
"Action: FooBar\n",
"Action Input: \"Who is FooBar?\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mFooBar is well and alive !\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'FooBar is well and alive !'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Who is FooBar\")"
] ]
}, },
{ {
@ -220,7 +272,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "6f12eaf0", "id": "6f12eaf0",
"metadata": {}, "metadata": {},
@ -459,7 +510,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "61d2e80b", "id": "61d2e80b",
"metadata": {}, "metadata": {},
@ -470,7 +520,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "5be41722", "id": "5be41722",
"metadata": {}, "metadata": {},
@ -499,7 +548,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "fb0a38eb", "id": "fb0a38eb",
"metadata": {}, "metadata": {},
@ -561,7 +609,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "7d68b0ac", "id": "7d68b0ac",
"metadata": {}, "metadata": {},
@ -589,7 +636,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "1d0430d6", "id": "1d0430d6",
"metadata": {}, "metadata": {},
@ -857,7 +903,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.2" "version": "3.9.11"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -619,7 +619,7 @@ class ExceptionTool(BaseTool):
class AgentExecutor(Chain): class AgentExecutor(Chain):
"""Consists of an agent using tools.""" """Consists of an agent using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] agent: Union[Agent, BaseSingleActionAgent, BaseMultiActionAgent]
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
@ -703,6 +703,23 @@ class AgentExecutor(Chain):
"""Lookup tool by name.""" """Lookup tool by name."""
return {tool.name: tool for tool in self.tools}[name] return {tool.name: tool for tool in self.tools}[name]
def add_tools(self, tools: Sequence[BaseTool]) -> None:
"""Add extra tools to an active agent instance."""
self.agent._validate_tools(tools)
#HACK: should not cast to list
self.tools = list(self.tools) + list(tools)
# update allowed_tools
new_tool_names = [tool.name for tool in tools]
self.agent._allowed_tools = self.agent.get_allowed_tools().extend(new_tool_names)
# how to update the agent prompt
# update the agent's llm_chain prompt
self.agent.llm_chain.prompt.update
def _should_continue(self, iterations: int, time_elapsed: float) -> bool: def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations: if self.max_iterations is not None and iterations >= self.max_iterations:
return False return False

View File

@ -71,36 +71,34 @@ class ConversationalChatAgent(Agent):
output_parser: Optional[BaseOutputParser] = None, output_parser: Optional[BaseOutputParser] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
tool_strings = "\n".join( tool_strings = "\n".join(
[f"> {tool.name}: {tool.description}" for tool in tools] [f"> {tool.name}: {tool.description}" for tool in tools])
)
tool_names = ", ".join([tool.name for tool in tools]) tool_names = ", ".join([tool.name for tool in tools])
_output_parser = output_parser or cls._get_default_output_parser() _output_parser = output_parser or cls._get_default_output_parser()
format_instructions = human_message.format( format_instructions = human_message.format(
format_instructions=_output_parser.get_format_instructions() format_instructions=_output_parser.get_format_instructions())
) final_prompt = format_instructions.format(tool_names=tool_names,
final_prompt = format_instructions.format( tools=tool_strings)
tool_names=tool_names, tools=tool_strings
)
if input_variables is None: if input_variables is None:
input_variables = ["input", "chat_history", "agent_scratchpad"] input_variables = ["input", "chat_history", "agent_scratchpad"]
messages = [ messages = [
SystemMessagePromptTemplate.from_template(system_message), SystemMessagePromptTemplate.from_template(system_message),
MessagesPlaceholder(variable_name="chat_history"), MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(final_prompt), HumanMessagePromptTemplate.from_template(final_prompt,
alias="instructions_tools"),
MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="agent_scratchpad"),
] ]
return ChatPromptTemplate(input_variables=input_variables, messages=messages) return ChatPromptTemplate(input_variables=input_variables,
messages=messages)
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: List[Tuple[AgentAction,
) -> List[BaseMessage]: str]]) -> List[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process.""" """Construct the scratchpad that lets the agent continue its thought process."""
thoughts: List[BaseMessage] = [] thoughts: List[BaseMessage] = []
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log)) thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage( human_message = HumanMessage(content=TEMPLATE_TOOL_RESPONSE.format(
content=TEMPLATE_TOOL_RESPONSE.format(observation=observation) observation=observation))
)
thoughts.append(human_message) thoughts.append(human_message)
return thoughts return thoughts

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -21,6 +21,8 @@ from langchain.schema import (
class BaseMessagePromptTemplate(BaseModel, ABC): class BaseMessagePromptTemplate(BaseModel, ABC):
alias: Optional[str] = None
@abstractmethod @abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages.""" """To messages."""
@ -212,6 +214,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
raise ValueError(f"Unexpected input: {message_template}") raise ValueError(f"Unexpected input: {message_template}")
return result return result
def update(self, msg_alias: str,
replacement_msg: Union[BaseMessagePromptTemplate, BaseMessage]) -> None:
for i, message in enumerate(self.messages):
if isinstance(message, BaseMessage) and message.alias == msg_alias:
if type(message) != type(replacement_msg):
raise ValueError("Replacement message is not of the same type as the original.")
self.messages[i] = replacement_msg
return
raise ValueError(f"No message with alias {msg_alias} found.")
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
raise NotImplementedError raise NotImplementedError

View File

@ -68,6 +68,7 @@ class BaseMessage(BaseModel):
"""Message object.""" """Message object."""
content: str content: str
alias: Optional[str] = None
additional_kwargs: dict = Field(default_factory=dict) additional_kwargs: dict = Field(default_factory=dict)
@property @property