diff --git a/agents/manager.py b/agents/manager.py index 2e14267..d675140 100644 --- a/agents/manager.py +++ b/agents/manager.py @@ -34,7 +34,7 @@ class AgentManager: *self.global_tools, *ToolsFactory.create_per_session_tools( self.toolsets, - session, + get_session=lambda: (session, self.executors[session]), ), ], memory=memory, @@ -47,7 +47,7 @@ class AgentManager: def get_or_create_executor(self, session: str) -> AgentExecutor: if not (session in self.executors): - self.executors[session] = self.create_executor(session) + self.executors[session] = self.create_executor(session=session) return self.executors[session] @staticmethod diff --git a/tools/base.py b/tools/base.py index f2dc980..14f2f1d 100644 --- a/tools/base.py +++ b/tools/base.py @@ -1,7 +1,8 @@ -from typing import Callable, Optional, Any +from typing import Optional, Callable, Tuple from enum import Enum from langchain.agents.tools import Tool, BaseTool +from langchain.agents.agent import AgentExecutor class ToolScope(Enum): @@ -9,6 +10,9 @@ class ToolScope(Enum): SESSION = "session" +SessionGetter = Callable[[], Tuple[str, AgentExecutor]] + + def tool( name: str, description: str, @@ -37,10 +41,15 @@ class ToolWrapper: def is_per_session(self) -> bool: return self.scope == ToolScope.SESSION - def to_tool(self, session: Optional[str] = None) -> BaseTool: + def to_tool( + self, + get_session: SessionGetter = lambda: [], + ) -> BaseTool: func = self.func if self.is_per_session(): - func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session) + func = lambda *args, **kwargs: self.func( + *args, **kwargs, get_session=get_session + ) return Tool( name=self.name, diff --git a/tools/cpu.py b/tools/cpu.py index 7607a19..a296260 100644 --- a/tools/cpu.py +++ b/tools/cpu.py @@ -1,17 +1,15 @@ from env import settings -from typing import Dict import requests from llama_index.readers.database import DatabaseReader from llama_index import GPTSimpleVectorIndex from bs4 import BeautifulSoup -from langchain.agents.agent import AgentExecutor import subprocess -from tools.base import tool, BaseToolSet, ToolScope +from tools.base import tool, BaseToolSet, ToolScope, SessionGetter from logger import logger @@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet): "The output will be a message that the conversation is over.", scope=ToolScope.SESSION, ) - def exit(self, *args, session: str) -> str: + def exit(self, *args, get_session: SessionGetter) -> str: """Run the tool.""" - self.executors.pop(session) + _, executor = get_session() + del executor logger.debug(f"\nProcessed ExitConversation.") diff --git a/tools/factory.py b/tools/factory.py index aa6106a..a327959 100644 --- a/tools/factory.py +++ b/tools/factory.py @@ -1,10 +1,9 @@ from typing import Optional - from langchain.agents import load_tools from langchain.agents.tools import BaseTool from langchain.llms.base import BaseLLM -from tools.base import BaseToolSet +from tools.base import BaseToolSet, SessionGetter class ToolsFactory: @@ -13,7 +12,7 @@ class ToolsFactory: toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, - session: Optional[str] = None, + get_session: SessionGetter = lambda: [], ) -> list[BaseTool]: tools = [] for wrapper in toolset.tool_wrappers(): @@ -21,7 +20,7 @@ class ToolsFactory: continue if only_per_session and not wrapper.is_per_session(): continue - tools.append(wrapper.to_tool(session)) + tools.append(wrapper.to_tool(get_session=get_session)) return tools @staticmethod @@ -41,7 +40,7 @@ class ToolsFactory: @staticmethod def create_per_session_tools( toolsets: list[BaseToolSet], - session: Optional[str] = None, + get_session: SessionGetter = lambda: [], ) -> list[BaseTool]: tools = [] for toolset in toolsets: @@ -49,7 +48,7 @@ class ToolsFactory: ToolsFactory.from_toolset( toolset=toolset, only_per_session=True, - session=session, + get_session=get_session, ) ) return tools