From a5dd73c1a646473ed94bd8525721540f82742122 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 2 May 2023 18:58:05 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"[agent][property=20type]=20Change=20a?= =?UTF-8?q?llowed=5Ftools=20to=20Set=20as=20Duplicate=20doesn=E2=80=99t=20?= =?UTF-8?q?make=20sense"=20(#4014)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts hwchase17/langchain#3840 --- langchain/agents/agent.py | 25 +++++++++++---------- tests/integration_tests/agent/__init__.py | 1 - tests/integration_tests/agent/test_agent.py | 16 ------------- 3 files changed, 13 insertions(+), 29 deletions(-) delete mode 100644 tests/integration_tests/agent/__init__.py delete mode 100644 tests/integration_tests/agent/test_agent.py diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 51972d69..fbb1ff6c 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -7,7 +7,7 @@ import logging import time from abc import abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import yaml from pydantic import BaseModel, root_validator @@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Set[str]: - return set() + def get_allowed_tools(self) -> Optional[List[str]]: + return None @abstractmethod def plan( @@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Set[str]: - return set() + def get_allowed_tools(self) -> Optional[List[str]]: + return None @abstractmethod def plan( @@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent): llm_chain: LLMChain output_parser: AgentOutputParser - allowed_tools: Set[str] = set() + allowed_tools: Optional[List[str]] = None - def get_allowed_tools(self) -> Set[str]: + def get_allowed_tools(self) -> Optional[List[str]]: return self.allowed_tools @property @@ -629,11 +629,12 @@ class AgentExecutor(Chain): agent = values["agent"] tools = values["tools"] allowed_tools = agent.get_allowed_tools() - if allowed_tools != set([tool.name for tool in tools]): - raise ValueError( - f"Allowed tools ({allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) + if allowed_tools is not None: + if set(allowed_tools) != set([tool.name for tool in tools]): + raise ValueError( + f"Allowed tools ({allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) return values @root_validator() diff --git a/tests/integration_tests/agent/__init__.py b/tests/integration_tests/agent/__init__.py deleted file mode 100644 index 117480e1..00000000 --- a/tests/integration_tests/agent/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""All integration tests for agent.""" diff --git a/tests/integration_tests/agent/test_agent.py b/tests/integration_tests/agent/test_agent.py deleted file mode 100644 index b423414e..00000000 --- a/tests/integration_tests/agent/test_agent.py +++ /dev/null @@ -1,16 +0,0 @@ -from langchain.agents.chat.base import ChatAgent -from langchain.llms.openai import OpenAI -from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun - - -class TestAgent: - def test_agent_generation(self) -> None: - web_search = DuckDuckGoSearchRun() - tools = [web_search] - agent = ChatAgent.from_llm_and_tools( - ai_name="Tom", - ai_role="Assistant", - tools=tools, - llm=OpenAI(maxTokens=10), - ) - assert agent.allowed_tools == set([web_search.name])