langchain/libs/community/tests/unit_tests/test_imports.py
Eugene Yurtsev 36813d2f00
community[patch]: Fix remaining __inits__ in community (#22037)
Fixes the __init__ files in community to use __all__ which is statically
defined.
2024-05-22 17:42:17 +00:00

171 lines
5.8 KiB
Python

import ast
import glob
import importlib
from pathlib import Path
from typing import List, Tuple
# Directory that contains this test module.
HERE = Path(__file__).parent
# Two directory levels above this test module's directory.
ROOT = HERE.parent.parent
# Expected location of the `langchain_community` package sources,
# i.e. three levels above this file plus "langchain_community".
COMMUNITY_ROOT = ROOT / "langchain_community"
# Glob pattern (POSIX separators) used by the tests to enumerate .py files.
# NOTE(review): the tests pass this to glob.glob() without recursive=True;
# per the glob documentation "**" then behaves like "*", so only files
# exactly one directory below COMMUNITY_ROOT match -- confirm this is intended.
ALL_COMMUNITY_GLOB = f"{COMMUNITY_ROOT.as_posix()}/**/*.py"
def test_importable_all() -> None:
    """Import each globbed community module and resolve every name in ``__all__``.

    For every path matched by ``ALL_COMMUNITY_GLOB`` the corresponding
    ``langchain_community.<module>`` is imported, then each name listed in the
    module's ``__all__`` (if any) is looked up with ``getattr``.

    The dotted module name is built from the path *components* rather than by
    replacing "/" in the path string, so it does not depend on the platform's
    path separator, and a top-level ``__init__.py`` maps to the package itself.

    NOTE(review): ``glob.glob`` is called without ``recursive=True``, so "**"
    does not recurse into nested packages -- confirm this is intended.

    Raises:
        ModuleNotFoundError: if a matched module cannot be imported; the
            message names the module and the file it was derived from.
    """
    for path in glob.glob(ALL_COMMUNITY_GLOB):
        # Path relative to the community package root.
        relative_path = Path(path).relative_to(COMMUNITY_ROOT)
        if relative_path.name == "__init__.py":
            # A package is imported by its directory name.
            parts = relative_path.parent.parts
        else:
            parts = relative_path.with_suffix("").parts
        module_name = ".".join(parts)
        qualified_name = ".".join(("langchain_community", *parts))

        try:
            module = importlib.import_module(qualified_name)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Could not import `{module_name}`. Defined in path: {path}"
            ) from e

        # Touch every exported name; a missing attribute fails the test.
        for exported_name in getattr(module, "__all__", []):
            getattr(module, exported_name)
def test_glob_correct() -> None:
    """Sanity-check the glob pattern by looking for one known file."""
    # Every match, expressed relative to the community package root.
    relative_matches = {
        Path(match).relative_to(COMMUNITY_ROOT)
        for match in glob.glob(ALL_COMMUNITY_GLOB)
    }
    # The callbacks package's __init__ must be among the matches.
    assert Path("callbacks/__init__.py") in relative_matches
def _check_correct_or_not_defined__all__(code: str) -> bool:
"""Return True if __all__ is correctly defined or not defined at all."""
# Parse the code into an AST
tree = ast.parse(code)
all_good = True
# Iterate through the body of the AST to find assignments
for node in tree.body:
# Check if the node is an assignment
if isinstance(node, ast.Assign):
# Check if the target of the assignment is '__all__'
for target in node.targets:
if isinstance(target, ast.Name) and target.id == "__all__":
# Check if the value assigned is a list
if isinstance(node.value, ast.List):
# Verify all elements in the list are string literals
if all(isinstance(el, ast.Str) for el in node.value.elts):
pass
else:
all_good = False
else:
all_good = False
return all_good
def test_no_dynamic__all__() -> None:
    """Verify that __all__ is not computed at runtime.

    Computing __all__ dynamically can confuse static typing tools like pyright.
    __all__ should always be listed as an explicit list of string literals.

    Only ``__init__.py`` files matched by ``ALL_COMMUNITY_GLOB`` are checked.
    Files are read as UTF-8 so the result does not depend on the locale's
    default encoding.

    NOTE(review): ``glob.glob`` is called without ``recursive=True``, so "**"
    does not recurse into nested packages -- confirm this is intended.

    Raises:
        AssertionError: listing every file whose ``__all__`` is not a list of
            string literals.
    """
    bad_definitions = []
    for path in glob.glob(ALL_COMMUNITY_GLOB):
        if not path.endswith("__init__.py"):
            continue

        with open(path, "r", encoding="utf-8") as file:
            code = file.read()

        if not _check_correct_or_not_defined__all__(code):
            bad_definitions.append(path)

    if bad_definitions:
        raise AssertionError(
            f"__all__ is not correctly defined in the "
            f"following files: {sorted(bad_definitions)}"
        )
def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
"""Extract all TYPE CHECKING imports that import from langchain_community."""
imports: List[Tuple[str, str]] = []
tree = ast.parse(code)
class TypeCheckingVisitor(ast.NodeVisitor):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module:
for alias in node.names:
imports.append((node.module, alias.name))
class GlobalScopeVisitor(ast.NodeVisitor):
def visit_If(self, node: ast.If) -> None:
if (
isinstance(node.test, ast.Name)
and node.test.id == "TYPE_CHECKING"
and isinstance(node.test.ctx, ast.Load)
):
TypeCheckingVisitor().visit(node)
self.generic_visit(node)
GlobalScopeVisitor().visit(tree)
return imports
def test_init_files_properly_defined() -> None:
    """Verify ``__init__`` files that use dynamic imports are properly defined.

    This is part of a set of tests that verify that init files are properly
    defined if they're using dynamic imports.

    For every matched ``__init__.py`` whose source text contains
    ``__getattr__`` (and whose package is not in the exception list), the
    package must define ``__all__``, and every name in ``__all__`` must also be
    imported under an ``if TYPE_CHECKING:`` guard in that file.

    The dotted package name is built from path components (not by replacing
    "/" in the path string) and sources are read as UTF-8, so the test does not
    depend on the platform's path separator or default encoding.

    NOTE(review): ``glob.glob`` is called without ``recursive=True``, so "**"
    does not recurse into nested packages -- confirm this is intended.

    Raises:
        ModuleNotFoundError: if a selected package cannot be imported.
        AssertionError: if ``__all__`` is missing, or lists names that have no
            matching ``TYPE_CHECKING`` import.
    """
    # Please never ever add more modules to this list.
    # Do feel free to fix the underlying issues and remove exceptions
    # from the list.
    excepted_modules = {"llms"}  # NEVER ADD MORE MODULES TO THIS LIST

    for path in glob.glob(ALL_COMMUNITY_GLOB):
        # Path relative to the community package root.
        relative_path = Path(path).relative_to(COMMUNITY_ROOT)
        if relative_path.name != "__init__.py":
            continue

        package_parts = relative_path.parent.parts
        module_name = ".".join(package_parts)
        if module_name in excepted_modules:
            continue

        code = Path(path).read_text(encoding="utf-8")

        # Only init files that mention a dynamic __getattr__ are checked.
        if "__getattr__" not in code:
            continue

        try:
            module = importlib.import_module(
                ".".join(("langchain_community", *package_parts))
            )
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Could not import `{module_name}`. Defined in path: {path}"
            ) from e

        if not hasattr(module, "__all__"):
            raise AssertionError(
                f"__all__ not defined in {module_name}. This is required "
                f"if __getattr__ is defined."
            )

        # Names imported under TYPE_CHECKING; the source module is irrelevant here.
        type_checking_names = {
            name for _, name in _extract_type_checking_imports(code)
        }

        missing_imports = set(module.__all__) - type_checking_names

        assert (
            not missing_imports
        ), f"Missing imports: {missing_imports} in file path: {path}"