diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 51972d69..fbb1ff6c 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -7,7 +7,7 @@ import logging import time from abc import abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import yaml from pydantic import BaseModel, root_validator @@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Set[str]: - return set() + def get_allowed_tools(self) -> Optional[List[str]]: + return None @abstractmethod def plan( @@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Set[str]: - return set() + def get_allowed_tools(self) -> Optional[List[str]]: + return None @abstractmethod def plan( @@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent): llm_chain: LLMChain output_parser: AgentOutputParser - allowed_tools: Set[str] = set() + allowed_tools: Optional[List[str]] = None - def get_allowed_tools(self) -> Set[str]: + def get_allowed_tools(self) -> Optional[List[str]]: return self.allowed_tools @property @@ -629,11 +629,12 @@ class AgentExecutor(Chain): agent = values["agent"] tools = values["tools"] allowed_tools = agent.get_allowed_tools() - if allowed_tools != set([tool.name for tool in tools]): - raise ValueError( - f"Allowed tools ({allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) + if allowed_tools is not None: + if set(allowed_tools) != set([tool.name for tool in tools]): + raise ValueError( + f"Allowed tools ({allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) return values @root_validator() diff --git a/tests/integration_tests/agent/__init__.py b/tests/integration_tests/agent/__init__.py deleted file mode 100644 index 117480e1..00000000 --- a/tests/integration_tests/agent/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""All integration tests for agent.""" diff --git a/tests/integration_tests/agent/test_agent.py b/tests/integration_tests/agent/test_agent.py deleted file mode 100644 index b423414e..00000000 --- a/tests/integration_tests/agent/test_agent.py +++ /dev/null @@ -1,16 +0,0 @@ -from langchain.agents.chat.base import ChatAgent -from langchain.llms.openai import OpenAI -from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun - - -class TestAgent: - def test_agent_generation(self) -> None: - web_search = DuckDuckGoSearchRun() - tools = [web_search] - agent = ChatAgent.from_llm_and_tools( - ai_name="Tom", - ai_role="Assistant", - tools=tools, - llm=OpenAI(maxTokens=10), - ) - assert agent.allowed_tools == set([web_search.name])