forked from Archives/langchain
c7ca350cd3
In LangChain, all module classes are enumerated in the `__init__.py` file of the corresponding module, but some classes were missing and were not included in their module's `__init__.py`. This PR: - added the missing classes to the modules' `__init__.py` files - sorted the `__init__.py:__all__` variable value (a list of the class names) - renamed `langchain.tools.sql_database.tool.QueryCheckerTool` to `QuerySQLCheckerTool`, because it conflicted with `langchain.tools.spark_sql.tool.QueryCheckerTool` - changes to `pyproject.toml`: - added `pgvector` to `pyproject.toml:extended_testing` - added `pandas` to `pyproject.toml:[tool.poetry.group.test.dependencies]` - commented out `streamlit` in `callbacks/__init__.py`, because `streamlit` now requires Python >=3.7, !=3.9.7 - fixed duplicate names in `tools` - fixed the corresponding unit tests #### Who can review? @hwchase17 @dev2049
35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
from typing import List, Type
|
|
|
|
import langchain.tools
|
|
from langchain.tools import __all__ as tools_all
|
|
from langchain.tools.base import BaseTool, StructuredTool
|
|
|
|
# Classes exported from ``langchain.tools`` that ``_get_tool_classes`` skips
# even though they are BaseTool subclasses (presumably because they are generic
# bases rather than concrete tools with a fixed default name -- confirm).
_EXCLUDE = {BaseTool, StructuredTool}
|
|
|
|
|
|
def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseTool]]:
    """Collect the tool classes named in ``langchain.tools.__all__``.

    Each exported name is resolved with ``getattr`` on ``langchain.tools`` (so a
    name listed in ``__all__`` but missing from the module raises
    ``AttributeError``). Only ``BaseTool`` subclasses not listed in
    ``_EXCLUDE`` are kept.

    Args:
        skip_tools_without_default_names: When true, also drop classes whose
            pydantic ``name`` field defaults to ``None`` or ``""``.

    Returns:
        The matching classes, in the order they appear in ``__all__``.
    """

    def _keep(candidate: object) -> bool:
        # Exported functions, constants, etc. are not tool classes.
        if not (isinstance(candidate, type) and issubclass(candidate, BaseTool)):
            return False
        if candidate in _EXCLUDE:
            return False
        if skip_tools_without_default_names:
            default_name = candidate.__fields__["name"].default
            return default_name not in [None, ""]
        return True

    # Resolve every exported string to the object it names.
    candidates = (getattr(langchain.tools, exported) for exported in tools_all)
    return [candidate for candidate in candidates if _keep(candidate)]
|
|
|
|
|
|
def test_tool_names_unique() -> None:
    """Test that the default names for our core tools are unique.

    Considers only the tool classes returned by
    ``_get_tool_classes(skip_tools_without_default_names=True)``, i.e. those
    whose pydantic ``name`` field has a non-empty default. On failure, the
    assertion message lists every clashing name once, with the classes that
    share it.
    """
    from collections import defaultdict

    tool_classes = _get_tool_classes(skip_tools_without_default_names=True)

    # Group class names by their default tool name in a single pass. This
    # replaces the previous ``names.count(name)`` scan, which was O(n^2) and
    # reported each duplicate once per occurrence.
    classes_by_name = defaultdict(list)
    for tool_cls in tool_classes:
        classes_by_name[tool_cls.__fields__["name"].default].append(tool_cls.__name__)

    # Sorted for a deterministic, readable failure message.
    duplicates = {
        name: sorted(owners)
        for name, owners in sorted(classes_by_name.items())
        if len(owners) > 1
    }
    assert not duplicates, f"Tools share default names: {duplicates}"
|