EVAL/tools/factory.py

62 lines
1.8 KiB
Python
Raw Normal View History

2023-03-18 12:26:19 +00:00
from typing import Optional
from langchain.agents import load_tools
from langchain.agents.tools import BaseTool
2023-03-20 08:27:20 +00:00
from langchain.llms.base import BaseLLM
2023-03-18 12:26:19 +00:00
2023-03-22 04:48:27 +00:00
from tools.base import BaseToolSet, SessionGetter
2023-03-18 12:26:19 +00:00
class ToolsFactory:
    """Build lists of LangChain ``BaseTool`` objects from toolsets or tool names."""

    @staticmethod
    def from_toolset(
        toolset: BaseToolSet,
        only_global: Optional[bool] = False,
        only_per_session: Optional[bool] = False,
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Convert the wrappers of a single toolset into tools.

        Args:
            toolset: Supplies the wrappers via ``toolset.tool_wrappers()``.
            only_global: If truthy, keep only wrappers whose ``is_global()``
                is truthy.
            only_per_session: If truthy, keep only wrappers whose
                ``is_per_session()`` is truthy. When both flags are set, a
                wrapper must satisfy both filters.
            get_session: Forwarded unchanged to each ``wrapper.to_tool``.
                NOTE(review): the default returns ``[]``; confirm that this
                matches the ``SessionGetter`` contract.

        Returns:
            One tool per wrapper that passed the enabled filters, in the
            order yielded by ``tool_wrappers()``.
        """
        # Each predicate is only called when its flag is set. This keeps the
        # exact call pattern of the original skip-with-continue loop.
        return [
            wrapper.to_tool(get_session=get_session)
            for wrapper in toolset.tool_wrappers()
            if (not only_global or wrapper.is_global())
            and (not only_per_session or wrapper.is_per_session())
        ]

    @staticmethod
    def create_global_tools(
        toolsets: list[BaseToolSet],
    ) -> list[BaseTool]:
        """Collect the global tools of every toolset, in toolset order.

        Args:
            toolsets: Toolsets to scan; each is filtered with
                ``only_global=True``.

        Returns:
            A flat list concatenating each toolset's global tools.
        """
        tools: list[BaseTool] = []
        for toolset in toolsets:
            tools.extend(
                ToolsFactory.from_toolset(
                    toolset=toolset,
                    only_global=True,
                )
            )
        return tools

    @staticmethod
    def create_per_session_tools(
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Collect the per-session tools of every toolset, in toolset order.

        Args:
            toolsets: Toolsets to scan; each is filtered with
                ``only_per_session=True``.
            get_session: Forwarded to ``from_toolset`` and from there to each
                wrapper's ``to_tool``.

        Returns:
            A flat list concatenating each toolset's per-session tools.
        """
        tools: list[BaseTool] = []
        for toolset in toolsets:
            tools.extend(
                ToolsFactory.from_toolset(
                    toolset=toolset,
                    only_per_session=True,
                    get_session=get_session,
                )
            )
        return tools

    @staticmethod
    def create_global_tools_from_names(
        toolnames: list[str],
        llm: Optional[BaseLLM] = None,
    ) -> list[BaseTool]:
        """Load built-in LangChain tools by name.

        Args:
            toolnames: Names passed straight to LangChain's ``load_tools``.
            llm: Optional LLM forwarded as ``load_tools(..., llm=llm)``.
                Defaults to ``None`` to match its ``Optional`` annotation.
                Presumably ``load_tools`` rejects LLM-dependent tools when no
                LLM is given; confirm against the installed langchain version.

        Returns:
            Whatever ``load_tools`` returns for the given names.
        """
        return load_tools(toolnames, llm=llm)