mirror of
https://github.com/corca-ai/EVAL
synced 2024-10-30 09:20:44 +00:00
fix: inject get_session method
This commit is contained in:
parent
4c5be38958
commit
41237d5532
@ -34,7 +34,7 @@ class AgentManager:
|
|||||||
*self.global_tools,
|
*self.global_tools,
|
||||||
*ToolsFactory.create_per_session_tools(
|
*ToolsFactory.create_per_session_tools(
|
||||||
self.toolsets,
|
self.toolsets,
|
||||||
session,
|
get_session=lambda: (session, self.executors[session]),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
memory=memory,
|
memory=memory,
|
||||||
@ -47,7 +47,7 @@ class AgentManager:
|
|||||||
|
|
||||||
def get_or_create_executor(self, session: str) -> AgentExecutor:
|
def get_or_create_executor(self, session: str) -> AgentExecutor:
|
||||||
if not (session in self.executors):
|
if not (session in self.executors):
|
||||||
self.executors[session] = self.create_executor(session)
|
self.executors[session] = self.create_executor(session=session)
|
||||||
return self.executors[session]
|
return self.executors[session]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from typing import Callable, Optional, Any
|
from typing import Optional, Callable, Tuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from langchain.agents.tools import Tool, BaseTool
|
from langchain.agents.tools import Tool, BaseTool
|
||||||
|
from langchain.agents.agent import AgentExecutor
|
||||||
|
|
||||||
|
|
||||||
class ToolScope(Enum):
|
class ToolScope(Enum):
|
||||||
@ -9,6 +10,9 @@ class ToolScope(Enum):
|
|||||||
SESSION = "session"
|
SESSION = "session"
|
||||||
|
|
||||||
|
|
||||||
|
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
|
||||||
|
|
||||||
|
|
||||||
def tool(
|
def tool(
|
||||||
name: str,
|
name: str,
|
||||||
description: str,
|
description: str,
|
||||||
@ -37,10 +41,15 @@ class ToolWrapper:
|
|||||||
def is_per_session(self) -> bool:
|
def is_per_session(self) -> bool:
|
||||||
return self.scope == ToolScope.SESSION
|
return self.scope == ToolScope.SESSION
|
||||||
|
|
||||||
def to_tool(self, session: Optional[str] = None) -> BaseTool:
|
def to_tool(
|
||||||
|
self,
|
||||||
|
get_session: SessionGetter = lambda: [],
|
||||||
|
) -> BaseTool:
|
||||||
func = self.func
|
func = self.func
|
||||||
if self.is_per_session():
|
if self.is_per_session():
|
||||||
func = lambda *args, **kwargs: self.func(*args, **kwargs, session=session)
|
func = lambda *args, **kwargs: self.func(
|
||||||
|
*args, **kwargs, get_session=get_session
|
||||||
|
)
|
||||||
|
|
||||||
return Tool(
|
return Tool(
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
from env import settings
|
from env import settings
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_index.readers.database import DatabaseReader
|
from llama_index.readers.database import DatabaseReader
|
||||||
from llama_index import GPTSimpleVectorIndex
|
from llama_index import GPTSimpleVectorIndex
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
from tools.base import tool, BaseToolSet, ToolScope
|
from tools.base import tool, BaseToolSet, ToolScope, SessionGetter
|
||||||
from logger import logger
|
from logger import logger
|
||||||
|
|
||||||
|
|
||||||
@ -233,9 +231,10 @@ class ExitConversation(BaseToolSet):
|
|||||||
"The output will be a message that the conversation is over.",
|
"The output will be a message that the conversation is over.",
|
||||||
scope=ToolScope.SESSION,
|
scope=ToolScope.SESSION,
|
||||||
)
|
)
|
||||||
def exit(self, *args, session: str) -> str:
|
def exit(self, *args, get_session: SessionGetter) -> str:
|
||||||
"""Run the tool."""
|
"""Run the tool."""
|
||||||
self.executors.pop(session)
|
_, executor = get_session()
|
||||||
|
del executor
|
||||||
|
|
||||||
logger.debug(f"\nProcessed ExitConversation.")
|
logger.debug(f"\nProcessed ExitConversation.")
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain.agents import load_tools
|
from langchain.agents import load_tools
|
||||||
from langchain.agents.tools import BaseTool
|
from langchain.agents.tools import BaseTool
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
from tools.base import BaseToolSet
|
from tools.base import BaseToolSet, SessionGetter
|
||||||
|
|
||||||
|
|
||||||
class ToolsFactory:
|
class ToolsFactory:
|
||||||
@ -13,7 +12,7 @@ class ToolsFactory:
|
|||||||
toolset: BaseToolSet,
|
toolset: BaseToolSet,
|
||||||
only_global: Optional[bool] = False,
|
only_global: Optional[bool] = False,
|
||||||
only_per_session: Optional[bool] = False,
|
only_per_session: Optional[bool] = False,
|
||||||
session: Optional[str] = None,
|
get_session: SessionGetter = lambda: [],
|
||||||
) -> list[BaseTool]:
|
) -> list[BaseTool]:
|
||||||
tools = []
|
tools = []
|
||||||
for wrapper in toolset.tool_wrappers():
|
for wrapper in toolset.tool_wrappers():
|
||||||
@ -21,7 +20,7 @@ class ToolsFactory:
|
|||||||
continue
|
continue
|
||||||
if only_per_session and not wrapper.is_per_session():
|
if only_per_session and not wrapper.is_per_session():
|
||||||
continue
|
continue
|
||||||
tools.append(wrapper.to_tool(session))
|
tools.append(wrapper.to_tool(get_session=get_session))
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -41,7 +40,7 @@ class ToolsFactory:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create_per_session_tools(
|
def create_per_session_tools(
|
||||||
toolsets: list[BaseToolSet],
|
toolsets: list[BaseToolSet],
|
||||||
session: Optional[str] = None,
|
get_session: SessionGetter = lambda: [],
|
||||||
) -> list[BaseTool]:
|
) -> list[BaseTool]:
|
||||||
tools = []
|
tools = []
|
||||||
for toolset in toolsets:
|
for toolset in toolsets:
|
||||||
@ -49,7 +48,7 @@ class ToolsFactory:
|
|||||||
ToolsFactory.from_toolset(
|
ToolsFactory.from_toolset(
|
||||||
toolset=toolset,
|
toolset=toolset,
|
||||||
only_per_session=True,
|
only_per_session=True,
|
||||||
session=session,
|
get_session=get_session,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return tools
|
return tools
|
||||||
|
Loading…
Reference in New Issue
Block a user