2024-05-22 13:32:13 +00:00
|
|
|
import ast
|
2023-12-11 21:53:30 +00:00
|
|
|
import glob
|
|
|
|
import importlib
|
2023-12-21 21:45:42 +00:00
|
|
|
from pathlib import Path
|
2024-05-22 17:19:00 +00:00
|
|
|
from typing import List, Tuple
|
2023-12-11 21:53:30 +00:00
|
|
|
|
2024-05-22 13:32:13 +00:00
|
|
|
# Directory that contains this test module.
HERE = Path(__file__).parent

# Two directory levels above this test module's directory.
ROOT = HERE.parent.parent

# The ``langchain_community`` directory located directly under ROOT.
COMMUNITY_ROOT = ROOT / "langchain_community"

# Pattern the tests in this module hand to ``glob.glob``.
# NOTE(review): ``**`` only expands across multiple directory levels when
# ``glob.glob`` is called with ``recursive=True``; without it ``**`` behaves
# like a single ``*``, so only ``<COMMUNITY_ROOT>/<dir>/<file>.py`` matches.
ALL_COMMUNITY_GLOB = f"{COMMUNITY_ROOT.as_posix()}/**/*.py"
|
|
|
|
|
2023-12-11 21:53:30 +00:00
|
|
|
|
|
|
|
def test_importable_all() -> None:
    """Import every module matched by the glob and resolve each name in ``__all__``.

    For each path returned by ``glob.glob(ALL_COMMUNITY_GLOB)`` the path is
    converted to a dotted module name under ``langchain_community`` and
    imported. Every name listed in the module's ``__all__`` (if any) is then
    looked up with ``getattr`` so that lazily-resolved attributes are exercised.

    NOTE(review): ``glob.glob`` is called without ``recursive=True``, so the
    ``**`` in the pattern matches a single directory level only.

    Raises:
        ModuleNotFoundError: If a matched module cannot be imported; the
            message names the module and the file it was derived from.
        AttributeError: Propagated from ``getattr`` when a name listed in
            ``__all__`` cannot be resolved on the module.
    """
    for path in glob.glob(ALL_COMMUNITY_GLOB):
        # Relative to community root
        relative_path = Path(path).relative_to(COMMUNITY_ROOT)

        if relative_path.name == "__init__.py":
            # A package is imported under the name of its directory.
            module_parts = relative_path.parent.parts
        else:
            module_parts = relative_path.with_suffix("").parts

        # Join the path components instead of replacing "/" so the result
        # does not depend on the platform's path separator.
        module_name = ".".join(module_parts)
        qualified_name = ".".join(("langchain_community", *module_parts))

        try:
            module = importlib.import_module(qualified_name)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Could not import `{module_name}`. Defined in path: {path}"
            ) from e

        # Touch every exported name so broken lazy exports surface here.
        for exported_name in getattr(module, "__all__", []):
            getattr(module, exported_name)
|
2024-05-22 13:32:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_glob_correct() -> None:
    """Check that the community glob pattern picks up a known file.

    The matches of ``ALL_COMMUNITY_GLOB`` are made relative to
    ``COMMUNITY_ROOT`` and ``callbacks/__init__.py`` must be among them.
    """
    # Express every match relative to the community root.
    matched = {
        Path(found).relative_to(COMMUNITY_ROOT)
        for found in glob.glob(ALL_COMMUNITY_GLOB)
    }

    # The callbacks package init file is used as the known-present path.
    expected = Path("callbacks/__init__.py")
    assert expected in matched
|
|
|
|
|
|
|
|
|
|
|
|
def _check_correct_or_not_defined__all__(code: str) -> bool:
    """Return True if __all__ is correctly defined or not defined at all.

    Only plain top-level assignments (``__all__ = ...``) are inspected;
    annotated or augmented assignments and assignments nested inside other
    statements are not examined.

    Args:
        code: Python source text to analyse.

    Returns:
        False if any top-level assignment to ``__all__`` has a value that is
        not a list literal made up exclusively of string literals; True
        otherwise (including when ``__all__`` is never assigned).

    Raises:
        SyntaxError: Propagated from ``ast.parse`` if ``code`` is not valid
            Python.
    """
    tree = ast.parse(code)

    for node in tree.body:
        if not isinstance(node, ast.Assign):
            continue

        assigns_all = any(
            isinstance(target, ast.Name) and target.id == "__all__"
            for target in node.targets
        )
        if not assigns_all:
            continue

        value = node.value
        # Anything other than a list literal counts as a dynamic definition.
        if not isinstance(value, ast.List):
            return False

        # String literals are ``ast.Constant`` nodes holding a ``str``;
        # ``ast.Str`` was deprecated in Python 3.8 and removed in 3.12.
        if not all(
            isinstance(element, ast.Constant) and isinstance(element.value, str)
            for element in value.elts
        ):
            return False

    return True
|
|
|
|
|
|
|
|
|
|
|
|
def test_no_dynamic__all__() -> None:
    """Verify that __all__ is not computed at runtime.

    Computing __all__ dynamically can confuse static typing tools like pyright.

    __all__ should always be listed as an explicit list of string literals.

    Only paths ending in ``__init__.py`` among the matches of
    ``ALL_COMMUNITY_GLOB`` are checked.

    Raises:
        AssertionError: If one or more checked files define ``__all__`` in a
            way the static check rejects; the message lists those files in
            sorted order.
    """
    bad_definitions = []
    for path in glob.glob(ALL_COMMUNITY_GLOB):
        if not path.endswith("__init__.py"):
            continue

        # Python source is UTF-8 by default; decode explicitly so the result
        # does not depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as file:
            code = file.read()

        if not _check_correct_or_not_defined__all__(code):
            bad_definitions.append(path)

    if bad_definitions:
        raise AssertionError(
            f"__all__ is not correctly defined in the "
            f"following files: {sorted(bad_definitions)}"
        )
|
2024-05-22 17:19:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _extract_type_checking_imports(code: str) -> List[Tuple[str, str]]:
    """Extract ``from X import Y`` statements found under ``if TYPE_CHECKING:``.

    Both the bare form (``if TYPE_CHECKING:``) and the attribute form
    (``if typing.TYPE_CHECKING:``) are recognised. The whole ``if`` statement
    is walked, so imports in its ``else`` branch are collected as well.
    Results are not filtered by package, and the imported name is recorded
    as written (an ``as`` alias is ignored). ``from . import x`` style
    imports, which carry no module name, are skipped.

    Args:
        code: Python source text to analyse.

    Returns:
        A list of ``(module, imported_name)`` tuples in the order they are
        encountered.

    Raises:
        SyntaxError: Propagated from ``ast.parse`` if ``code`` is not valid
            Python.
    """
    imports: List[Tuple[str, str]] = []

    tree = ast.parse(code)

    def _is_type_checking_test(test: ast.expr) -> bool:
        """Return True for a ``TYPE_CHECKING`` or ``<obj>.TYPE_CHECKING`` test."""
        if isinstance(test, ast.Name):
            return test.id == "TYPE_CHECKING" and isinstance(test.ctx, ast.Load)
        if isinstance(test, ast.Attribute):
            return test.attr == "TYPE_CHECKING" and isinstance(test.ctx, ast.Load)
        return False

    class TypeCheckingVisitor(ast.NodeVisitor):
        """Record every ``from module import name`` beneath the visited node."""

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            # ``node.module`` is None for ``from . import x``; skip those.
            if node.module:
                imports.extend((node.module, alias.name) for alias in node.names)

    class GlobalScopeVisitor(ast.NodeVisitor):
        """Locate ``if`` statements guarded by TYPE_CHECKING anywhere in the tree."""

        def visit_If(self, node: ast.If) -> None:
            if _is_type_checking_test(node.test):
                TypeCheckingVisitor().visit(node)
            # Keep descending so guarded blocks nested in other statements
            # are found too.
            self.generic_visit(node)

    GlobalScopeVisitor().visit(tree)
    return imports
|
|
|
|
|
|
|
|
|
|
|
|
def test_init_files_properly_defined() -> None:
    """This is part of a set of tests that verify that init files are properly
    defined if they're using dynamic imports.

    For every ``__init__.py`` matched by ``ALL_COMMUNITY_GLOB`` whose source
    text contains ``__getattr__`` (and whose module is not in the exception
    list), the module is imported and must:

    * define ``__all__``; and
    * import every name listed in ``__all__`` inside an
      ``if TYPE_CHECKING:`` block.

    Raises:
        ModuleNotFoundError: If one of the checked packages cannot be imported.
        AssertionError: If ``__all__`` is missing, or if names in ``__all__``
            have no matching TYPE_CHECKING import.
    """
    # Please never ever add more modules to this list.
    # Do feel free to fix the underlying issues and remove exceptions
    # from the list.
    excepted_modules = {"llms"}  # NEVER ADD MORE MODULES TO THIS LIST
    for path in glob.glob(ALL_COMMUNITY_GLOB):
        # Relative to community root
        relative_path = Path(path).relative_to(COMMUNITY_ROOT)

        if relative_path.name != "__init__.py":
            continue

        # Join the path components instead of replacing "/" so the result
        # does not depend on the platform's path separator.
        module_parts = relative_path.parent.parts
        module_name = ".".join(module_parts)

        if module_name in excepted_modules:
            continue

        # Python source is UTF-8 by default; decode explicitly so the result
        # does not depend on the platform's locale encoding.
        code = Path(path).read_text(encoding="utf-8")

        # Check for dynamic __getattr__ definition in the __init__ file
        # (a plain substring test on the source text).
        if "__getattr__" not in code:
            continue

        qualified_name = ".".join(("langchain_community", *module_parts))
        try:
            module = importlib.import_module(qualified_name)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Could not import `{module_name}`. Defined in path: {path}"
            ) from e

        if not hasattr(module, "__all__"):
            raise AssertionError(
                f"__all__ not defined in {module_name}. This is required "
                f"if __getattr__ is defined."
            )

        # Get the names of all the TYPE CHECKING imports
        type_checking_names = {
            name for _, name in _extract_type_checking_imports(code)
        }

        missing_imports = set(module.__all__) - type_checking_names

        assert (
            not missing_imports
        ), f"Missing imports: {missing_imports} in file path: {path}"
|