mirror of
https://github.com/corca-ai/EVAL
synced 2024-10-30 09:20:44 +00:00
fix: inject get_session method
This commit is contained in:
parent
4c5be38958
commit
41237d5532
@ -34,7 +34,7 @@ class AgentManager:
|
||||
*self.global_tools,
|
||||
*ToolsFactory.create_per_session_tools(
|
||||
self.toolsets,
|
||||
session,
|
||||
get_session=lambda: (session, self.executors[session]),
|
||||
),
|
||||
],
|
||||
memory=memory,
|
||||
@ -47,7 +47,7 @@ class AgentManager:
|
||||
|
||||
def get_or_create_executor(self, session: str) -> AgentExecutor:
|
||||
if not (session in self.executors):
|
||||
self.executors[session] = self.create_executor(session)
|
||||
self.executors[session] = self.create_executor(session=session)
|
||||
return self.executors[session]
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,7 +1,8 @@
|
||||
from typing import Callable, Optional, Any
|
||||
from typing import Optional, Callable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
from langchain.agents.tools import Tool, BaseTool
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
||||
|
||||
class ToolScope(Enum):
|
||||
@ -9,6 +10,9 @@ class ToolScope(Enum):
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
|
||||
|
||||
|
||||
def tool(
|
||||
name: str,
|
||||
description: str,
|
||||
@ -37,10 +41,15 @@ class ToolWrapper:
|
||||
def is_per_session(self) -> bool:
|
||||
return self.scope == ToolScope.SESSION
|
||||
|
||||
def to_tool(self, session: Optional[str] = None) -> BaseTool:
|
||||
def to_tool(
|
||||
self,
|
||||
get_session: SessionGetter = lambda: [],
|
||||
) -> BaseTool:
|
||||
func = self.func
|
||||
if self.is_per_session():
|
||||
func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session)
|
||||
func = lambda *args, **kwargs: self.func(
|
||||
*args, **kwargs, get_session=get_session
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=self.name,
|
||||
|
@ -1,17 +1,15 @@
|
||||
from env import settings
|
||||
|
||||
from typing import Dict
|
||||
import requests
|
||||
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from llama_index import GPTSimpleVectorIndex
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
||||
import subprocess
|
||||
|
||||
from tools.base import tool, BaseToolSet, ToolScope
|
||||
from tools.base import tool, BaseToolSet, ToolScope, SessionGetter
|
||||
from logger import logger
|
||||
|
||||
|
||||
@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet):
|
||||
"The output will be a message that the conversation is over.",
|
||||
scope=ToolScope.SESSION,
|
||||
)
|
||||
def exit(self, *args, session: str) -> str:
|
||||
def exit(self, *args, get_session: SessionGetter) -> str:
|
||||
"""Run the tool."""
|
||||
self.executors.pop(session)
|
||||
_, executor = get_session()
|
||||
del executor
|
||||
|
||||
logger.debug(f"\nProcessed ExitConversation.")
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.agents import load_tools
|
||||
from langchain.agents.tools import BaseTool
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
from tools.base import BaseToolSet
|
||||
from tools.base import BaseToolSet, SessionGetter
|
||||
|
||||
|
||||
class ToolsFactory:
|
||||
@ -13,7 +12,7 @@ class ToolsFactory:
|
||||
toolset: BaseToolSet,
|
||||
only_global: Optional[bool] = False,
|
||||
only_per_session: Optional[bool] = False,
|
||||
session: Optional[str] = None,
|
||||
get_session: SessionGetter = lambda: [],
|
||||
) -> list[BaseTool]:
|
||||
tools = []
|
||||
for wrapper in toolset.tool_wrappers():
|
||||
@ -21,7 +20,7 @@ class ToolsFactory:
|
||||
continue
|
||||
if only_per_session and not wrapper.is_per_session():
|
||||
continue
|
||||
tools.append(wrapper.to_tool(session))
|
||||
tools.append(wrapper.to_tool(get_session=get_session))
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
@ -41,7 +40,7 @@ class ToolsFactory:
|
||||
@staticmethod
|
||||
def create_per_session_tools(
|
||||
toolsets: list[BaseToolSet],
|
||||
session: Optional[str] = None,
|
||||
get_session: SessionGetter = lambda: [],
|
||||
) -> list[BaseTool]:
|
||||
tools = []
|
||||
for toolset in toolsets:
|
||||
@ -49,7 +48,7 @@ class ToolsFactory:
|
||||
ToolsFactory.from_toolset(
|
||||
toolset=toolset,
|
||||
only_per_session=True,
|
||||
session=session,
|
||||
get_session=get_session,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
Loading…
Reference in New Issue
Block a user