Revert "[agent][property type] Change allowed_tools to Set as Duplicate doesn’t make sense" (#4014)

Reverts hwchase17/langchain#3840
fix_agent_callbacks
Harrison Chase 1 year ago committed by GitHub
parent df3bc707fc
commit a5dd73c1a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ import logging
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml
from pydantic import BaseModel, root_validator
@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel):
"""Return values of the agent."""
return ["output"]
def get_allowed_tools(self) -> Set[str]:
return set()
def get_allowed_tools(self) -> Optional[List[str]]:
return None
@abstractmethod
def plan(
@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel):
"""Return values of the agent."""
return ["output"]
def get_allowed_tools(self) -> Set[str]:
return set()
def get_allowed_tools(self) -> Optional[List[str]]:
return None
@abstractmethod
def plan(
@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent):
llm_chain: LLMChain
output_parser: AgentOutputParser
allowed_tools: Set[str] = set()
allowed_tools: Optional[List[str]] = None
def get_allowed_tools(self) -> Set[str]:
def get_allowed_tools(self) -> Optional[List[str]]:
return self.allowed_tools
@property
@ -629,11 +629,12 @@ class AgentExecutor(Chain):
agent = values["agent"]
tools = values["tools"]
allowed_tools = agent.get_allowed_tools()
if allowed_tools != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
return values
@root_validator()

@ -1 +0,0 @@
"""All integration tests for agent."""

@ -1,16 +0,0 @@
from langchain.agents.chat.base import ChatAgent
from langchain.llms.openai import OpenAI
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
class TestAgent:
def test_agent_generation(self) -> None:
web_search = DuckDuckGoSearchRun()
tools = [web_search]
agent = ChatAgent.from_llm_and_tools(
ai_name="Tom",
ai_role="Assistant",
tools=tools,
llm=OpenAI(maxTokens=10),
)
assert agent.allowed_tools == set([web_search.name])
Loading…
Cancel
Save