"""This script checks documentation for broken import statements.""" import importlib import json import logging import os import re import warnings from pathlib import Path from typing import List, Tuple logger = logging.getLogger(__name__) DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs" import_pattern = re.compile( r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL ) def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]: """Get (module, import) statements from a single code cell.""" import_statements = [] for line in code_lines: line = line.strip() if line.startswith("#") or not line: continue # Join lines that end with a backslash if line.endswith("\\"): line = line[:-1].rstrip() + " " continue matches = import_pattern.findall(line) for match in matches: if match[0]: # simple import statement import_statements.append((match[0], "")) else: # from ___ import statement module, items = match[1], match[2] items_list = items.replace(" ", "").split(",") for item in items_list: import_statements.append((module, item)) return import_statements def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]: """Get (module, import) statements from a Jupyter notebook.""" with open(notebook_path, "r", encoding="utf-8") as file: notebook = json.load(file) code_cells = [cell for cell in notebook["cells"] if cell["cell_type"] == "code"] import_statements = [] for cell in code_cells: code_lines = cell["source"] import_statements.extend(_get_imports_from_code_cell(code_lines)) return import_statements def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]: """Collect offending import statements.""" offending_imports = [] for module, item in import_statements: try: if item: try: # submodule full_module_name = f"{module}.{item}" importlib.import_module(full_module_name) except ModuleNotFoundError: # attribute try: imported_module = importlib.import_module(module) getattr(imported_module, item) except AttributeError: offending_imports.append((module, item)) except Exception: offending_imports.append((module, item)) else: importlib.import_module(module) except Exception: offending_imports.append((module, item)) return offending_imports def _is_relevant_import(module: str) -> bool: """Check if module is recognized.""" # Ignore things like langchain_{bla}, where bla is unrecognized. recognized_packages = [ "langchain", "langchain_core", "langchain_community", "langchain_experimental", "langchain_text_splitters", ] return module.split(".")[0] in recognized_packages def _serialize_bad_imports(bad_files: list) -> str: """Serialize bad imports to a string.""" bad_imports_str = "" for file, bad_imports in bad_files: bad_imports_str += f"File: {file}\n" for module, item in bad_imports: bad_imports_str += f" {module}.{item}\n" return bad_imports_str def check_notebooks(directory: str) -> list: """Check notebooks for broken import statements.""" bad_files = [] for root, _, files in os.walk(directory): for file in files: if file.endswith(".ipynb") and not file.endswith("-checkpoint.ipynb"): notebook_path = os.path.join(root, file) import_statements = [ (module, item) for module, item in _extract_import_statements(notebook_path) if _is_relevant_import(module) ] bad_imports = _get_bad_imports(import_statements) if bad_imports: bad_files.append( ( os.path.join(root, file), bad_imports, ) ) return bad_files if __name__ == "__main__": bad_files = check_notebooks(DOCS_DIR) if bad_files: raise ImportError("Found bad imports:\n" f"{_serialize_bad_imports(bad_files)}")