forked from Archives/langchain
Revert "[agent][property type] Change allowed_tools to Set as Duplicate doesn’t make sense" (#4014)
Reverts hwchase17/langchain#3840
This commit is contained in:
parent
df3bc707fc
commit
a5dd73c1a6
@ -7,7 +7,7 @@ import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel):
|
||||
"""Return values of the agent."""
|
||||
return ["output"]
|
||||
|
||||
def get_allowed_tools(self) -> Set[str]:
|
||||
return set()
|
||||
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel):
|
||||
"""Return values of the agent."""
|
||||
return ["output"]
|
||||
|
||||
def get_allowed_tools(self) -> Set[str]:
|
||||
return set()
|
||||
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent):
|
||||
|
||||
llm_chain: LLMChain
|
||||
output_parser: AgentOutputParser
|
||||
allowed_tools: Set[str] = set()
|
||||
allowed_tools: Optional[List[str]] = None
|
||||
|
||||
def get_allowed_tools(self) -> Set[str]:
|
||||
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||
return self.allowed_tools
|
||||
|
||||
@property
|
||||
@ -629,7 +629,8 @@ class AgentExecutor(Chain):
|
||||
agent = values["agent"]
|
||||
tools = values["tools"]
|
||||
allowed_tools = agent.get_allowed_tools()
|
||||
if allowed_tools != set([tool.name for tool in tools]):
|
||||
if allowed_tools is not None:
|
||||
if set(allowed_tools) != set([tool.name for tool in tools]):
|
||||
raise ValueError(
|
||||
f"Allowed tools ({allowed_tools}) different than "
|
||||
f"provided tools ({[tool.name for tool in tools]})"
|
||||
|
@ -1 +0,0 @@
|
||||
"""All integration tests for agent."""
|
@ -1,16 +0,0 @@
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
|
||||
|
||||
|
||||
class TestAgent:
|
||||
def test_agent_generation(self) -> None:
|
||||
web_search = DuckDuckGoSearchRun()
|
||||
tools = [web_search]
|
||||
agent = ChatAgent.from_llm_and_tools(
|
||||
ai_name="Tom",
|
||||
ai_role="Assistant",
|
||||
tools=tools,
|
||||
llm=OpenAI(maxTokens=10),
|
||||
)
|
||||
assert agent.allowed_tools == set([web_search.name])
|
Loading…
Reference in New Issue
Block a user