diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index eb9b859f..ca382031 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -7,7 +7,7 @@ import logging import time from abc import abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import yaml from pydantic import BaseModel, root_validator @@ -46,8 +46,8 @@ class BaseSingleActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Optional[List[str]]: - return None + def get_allowed_tools(self) -> Set[str]: + return set() @abstractmethod def plan( @@ -178,8 +178,8 @@ class BaseMultiActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Optional[List[str]]: - return None + def get_allowed_tools(self) -> Set[str]: + return set() @abstractmethod def plan( @@ -372,9 +372,9 @@ class Agent(BaseSingleActionAgent): llm_chain: LLMChain output_parser: AgentOutputParser - allowed_tools: Optional[List[str]] = None + allowed_tools: Set[str] = set() - def get_allowed_tools(self) -> Optional[List[str]]: + def get_allowed_tools(self) -> Set[str]: return self.allowed_tools @property @@ -607,12 +607,11 @@ class AgentExecutor(Chain): agent = values["agent"] tools = values["tools"] allowed_tools = agent.get_allowed_tools() - if allowed_tools is not None: - if set(allowed_tools) != set([tool.name for tool in tools]): - raise ValueError( - f"Allowed tools ({allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) + if allowed_tools != set([tool.name for tool in tools]): + raise ValueError( + f"Allowed tools ({allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) return values @root_validator() diff --git a/tests/integration_tests/agent/__init__.py b/tests/integration_tests/agent/__init__.py new file mode 100644 index 00000000..117480e1 --- /dev/null +++ b/tests/integration_tests/agent/__init__.py @@ -0,0 +1 @@ +"""All integration tests for agent.""" diff --git a/tests/integration_tests/agent/test_agent.py b/tests/integration_tests/agent/test_agent.py new file mode 100644 index 00000000..b423414e --- /dev/null +++ b/tests/integration_tests/agent/test_agent.py @@ -0,0 +1,16 @@ +from langchain.agents.chat.base import ChatAgent +from langchain.llms.openai import OpenAI +from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun + + +class TestAgent: + def test_agent_generation(self) -> None: + web_search = DuckDuckGoSearchRun() + tools = [web_search] + agent = ChatAgent.from_llm_and_tools( + ai_name="Tom", + ai_role="Assistant", + tools=tools, + llm=OpenAI(maxTokens=10), + ) + assert agent.allowed_tools == set([web_search.name])