fix: inject get_session method

This commit is contained in:
hanchchch 2023-03-22 04:48:27 +00:00
parent 4c5be38958
commit 41237d5532
4 changed files with 23 additions and 16 deletions

View File

@ -34,7 +34,7 @@ class AgentManager:
*self.global_tools, *self.global_tools,
*ToolsFactory.create_per_session_tools( *ToolsFactory.create_per_session_tools(
self.toolsets, self.toolsets,
session, get_session=lambda: (session, self.executors[session]),
), ),
], ],
memory=memory, memory=memory,
@ -47,7 +47,7 @@ class AgentManager:
def get_or_create_executor(self, session: str) -> AgentExecutor: def get_or_create_executor(self, session: str) -> AgentExecutor:
if not (session in self.executors): if not (session in self.executors):
self.executors[session] = self.create_executor(session) self.executors[session] = self.create_executor(session=session)
return self.executors[session] return self.executors[session]
@staticmethod @staticmethod

View File

@ -1,7 +1,8 @@
from typing import Callable, Optional, Any from typing import Optional, Callable, Tuple
from enum import Enum from enum import Enum
from langchain.agents.tools import Tool, BaseTool from langchain.agents.tools import Tool, BaseTool
from langchain.agents.agent import AgentExecutor
class ToolScope(Enum): class ToolScope(Enum):
@ -9,6 +10,9 @@ class ToolScope(Enum):
SESSION = "session" SESSION = "session"
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool( def tool(
name: str, name: str,
description: str, description: str,
@ -37,10 +41,15 @@ class ToolWrapper:
def is_per_session(self) -> bool: def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION return self.scope == ToolScope.SESSION
def to_tool(self, session: Optional[str] = None) -> BaseTool: def to_tool(
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func func = self.func
if self.is_per_session(): if self.is_per_session():
func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session) func = lambda *args, **kwargs: self.func(
*args, **kwargs, get_session=get_session
)
return Tool( return Tool(
name=self.name, name=self.name,

View File

@ -1,17 +1,15 @@
from env import settings from env import settings
from typing import Dict
import requests import requests
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from llama_index import GPTSimpleVectorIndex from llama_index import GPTSimpleVectorIndex
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain.agents.agent import AgentExecutor
import subprocess import subprocess
from tools.base import tool, BaseToolSet, ToolScope from tools.base import tool, BaseToolSet, ToolScope, SessionGetter
from logger import logger from logger import logger
@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet):
"The output will be a message that the conversation is over.", "The output will be a message that the conversation is over.",
scope=ToolScope.SESSION, scope=ToolScope.SESSION,
) )
def exit(self, *args, session: str) -> str: def exit(self, *args, get_session: SessionGetter) -> str:
"""Run the tool.""" """Run the tool."""
self.executors.pop(session) _, executor = get_session()
del executor
logger.debug(f"\nProcessed ExitConversation.") logger.debug(f"\nProcessed ExitConversation.")

View File

@ -1,10 +1,9 @@
from typing import Optional from typing import Optional
from langchain.agents import load_tools from langchain.agents import load_tools
from langchain.agents.tools import BaseTool from langchain.agents.tools import BaseTool
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from tools.base import BaseToolSet from tools.base import BaseToolSet, SessionGetter
class ToolsFactory: class ToolsFactory:
@ -13,7 +12,7 @@ class ToolsFactory:
toolset: BaseToolSet, toolset: BaseToolSet,
only_global: Optional[bool] = False, only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False, only_per_session: Optional[bool] = False,
session: Optional[str] = None, get_session: SessionGetter = lambda: [],
) -> list[BaseTool]: ) -> list[BaseTool]:
tools = [] tools = []
for wrapper in toolset.tool_wrappers(): for wrapper in toolset.tool_wrappers():
@ -21,7 +20,7 @@ class ToolsFactory:
continue continue
if only_per_session and not wrapper.is_per_session(): if only_per_session and not wrapper.is_per_session():
continue continue
tools.append(wrapper.to_tool(session)) tools.append(wrapper.to_tool(get_session=get_session))
return tools return tools
@staticmethod @staticmethod
@ -41,7 +40,7 @@ class ToolsFactory:
@staticmethod @staticmethod
def create_per_session_tools( def create_per_session_tools(
toolsets: list[BaseToolSet], toolsets: list[BaseToolSet],
session: Optional[str] = None, get_session: SessionGetter = lambda: [],
) -> list[BaseTool]: ) -> list[BaseTool]:
tools = [] tools = []
for toolset in toolsets: for toolset in toolsets:
@ -49,7 +48,7 @@ class ToolsFactory:
ToolsFactory.from_toolset( ToolsFactory.from_toolset(
toolset=toolset, toolset=toolset,
only_per_session=True, only_per_session=True,
session=session, get_session=get_session,
) )
) )
return tools return tools