fix: inject get_session method

This commit is contained in:
hanchchch 2023-03-22 04:48:27 +00:00
parent 4c5be38958
commit 41237d5532
4 changed files with 23 additions and 16 deletions

View File

@ -34,7 +34,7 @@ class AgentManager:
*self.global_tools,
*ToolsFactory.create_per_session_tools(
self.toolsets,
session,
get_session=lambda: (session, self.executors[session]),
),
],
memory=memory,
@ -47,7 +47,7 @@ class AgentManager:
def get_or_create_executor(self, session: str) -> AgentExecutor:
if not (session in self.executors):
self.executors[session] = self.create_executor(session)
self.executors[session] = self.create_executor(session=session)
return self.executors[session]
@staticmethod

View File

@ -1,7 +1,8 @@
from typing import Callable, Optional, Any
from typing import Optional, Callable, Tuple
from enum import Enum
from langchain.agents.tools import Tool, BaseTool
from langchain.agents.agent import AgentExecutor
class ToolScope(Enum):
@ -9,6 +10,9 @@ class ToolScope(Enum):
SESSION = "session"
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
def tool(
name: str,
description: str,
@ -37,10 +41,15 @@ class ToolWrapper:
def is_per_session(self) -> bool:
return self.scope == ToolScope.SESSION
def to_tool(self, session: Optional[str] = None) -> BaseTool:
def to_tool(
self,
get_session: SessionGetter = lambda: [],
) -> BaseTool:
func = self.func
if self.is_per_session():
func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session)
func = lambda *args, **kwargs: self.func(
*args, **kwargs, get_session=get_session
)
return Tool(
name=self.name,

View File

@ -1,17 +1,15 @@
from env import settings
from typing import Dict
import requests
from llama_index.readers.database import DatabaseReader
from llama_index import GPTSimpleVectorIndex
from bs4 import BeautifulSoup
from langchain.agents.agent import AgentExecutor
import subprocess
from tools.base import tool, BaseToolSet, ToolScope
from tools.base import tool, BaseToolSet, ToolScope, SessionGetter
from logger import logger
@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet):
"The output will be a message that the conversation is over.",
scope=ToolScope.SESSION,
)
def exit(self, *args, session: str) -> str:
def exit(self, *args, get_session: SessionGetter) -> str:
"""Run the tool."""
self.executors.pop(session)
_, executor = get_session()
del executor
logger.debug(f"\nProcessed ExitConversation.")

View File

@ -1,10 +1,9 @@
from typing import Optional
from langchain.agents import load_tools
from langchain.agents.tools import BaseTool
from langchain.llms.base import BaseLLM
from tools.base import BaseToolSet
from tools.base import BaseToolSet, SessionGetter
class ToolsFactory:
@ -13,7 +12,7 @@ class ToolsFactory:
toolset: BaseToolSet,
only_global: Optional[bool] = False,
only_per_session: Optional[bool] = False,
session: Optional[str] = None,
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for wrapper in toolset.tool_wrappers():
@ -21,7 +20,7 @@ class ToolsFactory:
continue
if only_per_session and not wrapper.is_per_session():
continue
tools.append(wrapper.to_tool(session))
tools.append(wrapper.to_tool(get_session=get_session))
return tools
@staticmethod
@ -41,7 +40,7 @@ class ToolsFactory:
@staticmethod
def create_per_session_tools(
toolsets: list[BaseToolSet],
session: Optional[str] = None,
get_session: SessionGetter = lambda: [],
) -> list[BaseTool]:
tools = []
for toolset in toolsets:
@ -49,7 +48,7 @@ class ToolsFactory:
ToolsFactory.from_toolset(
toolset=toolset,
only_per_session=True,
session=session,
get_session=get_session,
)
)
return tools