forked from Archives/langchain
wip dynamic agent tools
This commit is contained in:
parent
8dcad0f272
commit
edbf7045d2
@ -1,7 +1,6 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "5436020b",
|
"id": "5436020b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -56,7 +55,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "f8bc72c2",
|
"id": "f8bc72c2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -69,7 +67,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "b63fcc3b",
|
"id": "b63fcc3b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -111,7 +108,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "e9b560f7",
|
"id": "e9b560f7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -145,7 +141,24 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
|
"id": "5f040378",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tools = [\n",
|
||||||
|
" Tool(\n",
|
||||||
|
" name=\"FooBar\",\n",
|
||||||
|
" description=\"Useful to answer questions related to FooBar\",\n",
|
||||||
|
" func=lambda x: \"FooBar is well and alive !\",\n",
|
||||||
|
" return_direct=True\n",
|
||||||
|
" )\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
"id": "5b93047d",
|
"id": "5b93047d",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -154,7 +167,46 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Construct the agent. We will use the default agent type here.\n",
|
"# Construct the agent. We will use the default agent type here.\n",
|
||||||
"# See documentation for a full list of options.\n",
|
"# See documentation for a full list of options.\n",
|
||||||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_iterations=3)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "2bb58cac",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI should use FooBar to answer this question\n",
|
||||||
|
"Action: FooBar\n",
|
||||||
|
"Action Input: \"Who is FooBar?\"\u001b[0m\n",
|
||||||
|
"Observation: \u001b[36;1m\u001b[1;3mFooBar is well and alive !\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'FooBar is well and alive !'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"agent.run(\"Who is FooBar\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -220,7 +272,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "6f12eaf0",
|
"id": "6f12eaf0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -459,7 +510,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "61d2e80b",
|
"id": "61d2e80b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -470,7 +520,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "5be41722",
|
"id": "5be41722",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -499,7 +548,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "fb0a38eb",
|
"id": "fb0a38eb",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -561,7 +609,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "7d68b0ac",
|
"id": "7d68b0ac",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -589,7 +636,6 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1d0430d6",
|
"id": "1d0430d6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -857,7 +903,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.2"
|
"version": "3.9.11"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
|
@ -619,7 +619,7 @@ class ExceptionTool(BaseTool):
|
|||||||
class AgentExecutor(Chain):
|
class AgentExecutor(Chain):
|
||||||
"""Consists of an agent using tools."""
|
"""Consists of an agent using tools."""
|
||||||
|
|
||||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
agent: Union[Agent, BaseSingleActionAgent, BaseMultiActionAgent]
|
||||||
tools: Sequence[BaseTool]
|
tools: Sequence[BaseTool]
|
||||||
return_intermediate_steps: bool = False
|
return_intermediate_steps: bool = False
|
||||||
max_iterations: Optional[int] = 15
|
max_iterations: Optional[int] = 15
|
||||||
@ -703,6 +703,23 @@ class AgentExecutor(Chain):
|
|||||||
"""Lookup tool by name."""
|
"""Lookup tool by name."""
|
||||||
return {tool.name: tool for tool in self.tools}[name]
|
return {tool.name: tool for tool in self.tools}[name]
|
||||||
|
|
||||||
|
def add_tools(self, tools: Sequence[BaseTool]) -> None:
|
||||||
|
"""Add extra tools to an active agent instance."""
|
||||||
|
self.agent._validate_tools(tools)
|
||||||
|
|
||||||
|
#HACK: should not cast to list
|
||||||
|
self.tools = list(self.tools) + list(tools)
|
||||||
|
|
||||||
|
# update allowed_tools
|
||||||
|
new_tool_names = [tool.name for tool in tools]
|
||||||
|
self.agent._allowed_tools = self.agent.get_allowed_tools().extend(new_tool_names)
|
||||||
|
|
||||||
|
# how to update the agent prompt
|
||||||
|
# update the agent's llm_chain prompt
|
||||||
|
self.agent.llm_chain.prompt.update
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
||||||
if self.max_iterations is not None and iterations >= self.max_iterations:
|
if self.max_iterations is not None and iterations >= self.max_iterations:
|
||||||
return False
|
return False
|
||||||
|
@ -71,36 +71,34 @@ class ConversationalChatAgent(Agent):
|
|||||||
output_parser: Optional[BaseOutputParser] = None,
|
output_parser: Optional[BaseOutputParser] = None,
|
||||||
) -> BasePromptTemplate:
|
) -> BasePromptTemplate:
|
||||||
tool_strings = "\n".join(
|
tool_strings = "\n".join(
|
||||||
[f"> {tool.name}: {tool.description}" for tool in tools]
|
[f"> {tool.name}: {tool.description}" for tool in tools])
|
||||||
)
|
|
||||||
tool_names = ", ".join([tool.name for tool in tools])
|
tool_names = ", ".join([tool.name for tool in tools])
|
||||||
_output_parser = output_parser or cls._get_default_output_parser()
|
_output_parser = output_parser or cls._get_default_output_parser()
|
||||||
format_instructions = human_message.format(
|
format_instructions = human_message.format(
|
||||||
format_instructions=_output_parser.get_format_instructions()
|
format_instructions=_output_parser.get_format_instructions())
|
||||||
)
|
final_prompt = format_instructions.format(tool_names=tool_names,
|
||||||
final_prompt = format_instructions.format(
|
tools=tool_strings)
|
||||||
tool_names=tool_names, tools=tool_strings
|
|
||||||
)
|
|
||||||
if input_variables is None:
|
if input_variables is None:
|
||||||
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessagePromptTemplate.from_template(system_message),
|
SystemMessagePromptTemplate.from_template(system_message),
|
||||||
MessagesPlaceholder(variable_name="chat_history"),
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
HumanMessagePromptTemplate.from_template(final_prompt),
|
HumanMessagePromptTemplate.from_template(final_prompt,
|
||||||
|
alias="instructions_tools"),
|
||||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||||
]
|
]
|
||||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
return ChatPromptTemplate(input_variables=input_variables,
|
||||||
|
messages=messages)
|
||||||
|
|
||||||
def _construct_scratchpad(
|
def _construct_scratchpad(
|
||||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
self, intermediate_steps: List[Tuple[AgentAction,
|
||||||
) -> List[BaseMessage]:
|
str]]) -> List[BaseMessage]:
|
||||||
"""Construct the scratchpad that lets the agent continue its thought process."""
|
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||||
thoughts: List[BaseMessage] = []
|
thoughts: List[BaseMessage] = []
|
||||||
for action, observation in intermediate_steps:
|
for action, observation in intermediate_steps:
|
||||||
thoughts.append(AIMessage(content=action.log))
|
thoughts.append(AIMessage(content=action.log))
|
||||||
human_message = HumanMessage(
|
human_message = HumanMessage(content=TEMPLATE_TOOL_RESPONSE.format(
|
||||||
content=TEMPLATE_TOOL_RESPONSE.format(observation=observation)
|
observation=observation))
|
||||||
)
|
|
||||||
thoughts.append(human_message)
|
thoughts.append(human_message)
|
||||||
return thoughts
|
return thoughts
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -21,6 +21,8 @@ from langchain.schema import (
|
|||||||
|
|
||||||
|
|
||||||
class BaseMessagePromptTemplate(BaseModel, ABC):
|
class BaseMessagePromptTemplate(BaseModel, ABC):
|
||||||
|
alias: Optional[str] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""To messages."""
|
"""To messages."""
|
||||||
@ -212,6 +214,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
raise ValueError(f"Unexpected input: {message_template}")
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def update(self, msg_alias: str,
|
||||||
|
replacement_msg: Union[BaseMessagePromptTemplate, BaseMessage]) -> None:
|
||||||
|
for i, message in enumerate(self.messages):
|
||||||
|
if isinstance(message, BaseMessage) and message.alias == msg_alias:
|
||||||
|
if type(message) != type(replacement_msg):
|
||||||
|
raise ValueError("Replacement message is not of the same type as the original.")
|
||||||
|
self.messages[i] = replacement_msg
|
||||||
|
return
|
||||||
|
raise ValueError(f"No message with alias {msg_alias} found.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ class BaseMessage(BaseModel):
|
|||||||
"""Message object."""
|
"""Message object."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
|
alias: Optional[str] = None
|
||||||
additional_kwargs: dict = Field(default_factory=dict)
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user