fix: inject get_session method

pull/5/head
hanchchch 1 year ago
parent 4c5be38958
commit 41237d5532

@ -34,7 +34,7 @@ class AgentManager:
*self.global_tools, *self.global_tools,
*ToolsFactory.create_per_session_tools( *ToolsFactory.create_per_session_tools(
self.toolsets, self.toolsets,
session, get_session=lambda: (session, self.executors[session]),
), ),
], ],
memory=memory, memory=memory,
@ -47,7 +47,7 @@ class AgentManager:
def get_or_create_executor(self, session: str) -> AgentExecutor: def get_or_create_executor(self, session: str) -> AgentExecutor:
if not (session in self.executors): if not (session in self.executors):
self.executors[session] = self.create_executor(session) self.executors[session] = self.create_executor(session=session)
return self.executors[session] return self.executors[session]
@staticmethod @staticmethod

@ -1,7 +1,8 @@
from typing import Callable, Optional, Any from typing import Optional, Callable, Tuple
from enum import Enum from enum import Enum
from langchain.agents.tools import Tool, BaseTool from langchain.agents.tools import Tool, BaseTool
from langchain.agents.agent import AgentExecutor
class ToolScope(Enum): class ToolScope(Enum):
@ -9,6 +10,9 @@ class ToolScope(Enum):
SESSION = "session" SESSION = "session"
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool( def tool(
name: str, name: str,
description: str, description: str,
@ -37,10 +41,15 @@ class ToolWrapper:
def is_per_session(self) -> bool: def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION return self.scope == ToolScope.SESSION
def to_tool(self, session: Optional[str] = None) -> BaseTool: def to_tool(
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func func = self.func
if self.is_per_session(): if self.is_per_session():
func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session) func = lambda *args, **kwargs: self.func(
*args, **kwargs, get_session=get_session
)
return Tool( return Tool(
name=self.name, name=self.name,

@ -1,17 +1,15 @@
from env import settings from env import settings
from typing import Dict
import requests import requests
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from llama_index import GPTSimpleVectorIndex from llama_index import GPTSimpleVectorIndex
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain.agents.agent import AgentExecutor
import subprocess import subprocess
from tools.base import tool, BaseToolSet, ToolScope from tools.base import tool, BaseToolSet, ToolScope, SessionGetter
from logger import logger from logger import logger
@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet):
"The output will be a message that the conversation is over.", "The output will be a message that the conversation is over.",
scope=ToolScope.SESSION, scope=ToolScope.SESSION,
) )
def exit(self, *args, session: str) -> str: def exit(self, *args, get_session: SessionGetter) -> str:
"""Run the tool.""" """Run the tool."""
self.executors.pop(session) _, executor = get_session()
del executor
logger.debug(f"\nProcessed ExitConversation.") logger.debug(f"\nProcessed ExitConversation.")

@ -1,10 +1,9 @@
from typing import Optional from typing import Optional
from langchain.agents import load_tools from langchain.agents import load_tools
from langchain.agents.tools import BaseTool from langchain.agents.tools import BaseTool
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from tools.base import BaseToolSet from tools.base import BaseToolSet, SessionGetter
class ToolsFactory: class ToolsFactory:
@ -13,7 +12,7 @@ class ToolsFactory:
toolset: BaseToolSet, toolset: BaseToolSet,
only_global: Optional[bool] = False, only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False, only_per_session: Optional[bool] = False,
session: Optional[str] = None, get_session: SessionGetter = lambda: [],
) -> list[BaseTool]: ) -> list[BaseTool]:
tools = [] tools = []
for wrapper in toolset.tool_wrappers(): for wrapper in toolset.tool_wrappers():
@ -21,7 +20,7 @@ class ToolsFactory:
continue continue
if only_per_session and not wrapper.is_per_session(): if only_per_session and not wrapper.is_per_session():
continue continue
tools.append(wrapper.to_tool(session)) tools.append(wrapper.to_tool(get_session=get_session))
return tools return tools
@staticmethod @staticmethod
@ -41,7 +40,7 @@ class ToolsFactory:
@staticmethod @staticmethod
def create_per_session_tools( def create_per_session_tools(
toolsets: list[BaseToolSet], toolsets: list[BaseToolSet],
session: Optional[str] = None, get_session: SessionGetter = lambda: [],
) -> list[BaseTool]: ) -> list[BaseTool]:
tools = [] tools = []
for toolset in toolsets: for toolset in toolsets:
@ -49,7 +48,7 @@ class ToolsFactory:
ToolsFactory.from_toolset( ToolsFactory.from_toolset(
toolset=toolset, toolset=toolset,
only_per_session=True, only_per_session=True,
session=session, get_session=get_session,
) )
) )
return tools return tools

Loading…
Cancel
Save