|
|
|
@ -4,9 +4,11 @@ import json
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import warnings
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import List, Tuple
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
from langchain_core._api import LangChainDeprecationWarning
|
|
|
|
|
|
|
|
|
|
# Module-level logger; the script entry point uses it to report deprecated imports.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -14,9 +16,10 @@ DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
|
|
|
|
|
# Matches either ``import <name>`` (capture group 1) or
# ``from <dotted.module> import <names>`` (capture groups 2 and 3).
# Group 3 is either a comma-separated run of names or a parenthesized list;
# re.DOTALL lets ``.`` cross newlines so a parenthesized list may span lines.
# Note that the plain ``import`` branch captures only ``\w+``, i.e. the first
# segment of a dotted module name.
import_pattern = re.compile(
    r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
)
# One imported name expressed as a ``(module, item)`` pair of strings.
Import = Tuple[str, str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
|
|
|
|
|
def _get_imports_from_code_cell(code_lines: str) -> List[Import]:
|
|
|
|
|
"""Get (module, import) statements from a single code cell."""
|
|
|
|
|
import_statements = []
|
|
|
|
|
for line in code_lines:
|
|
|
|
@ -39,7 +42,7 @@ def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
|
|
|
|
|
return import_statements
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
|
|
|
|
|
def _extract_import_statements(notebook_path: str) -> List[Import]:
|
|
|
|
|
"""Get (module, import) statements from a Jupyter notebook."""
|
|
|
|
|
with open(notebook_path, "r", encoding="utf-8") as file:
|
|
|
|
|
notebook = json.load(file)
|
|
|
|
@ -51,31 +54,43 @@ def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
|
|
|
|
|
return import_statements
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_bad_imports(
    import_statements: List[Tuple[str, str]],
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Collect offending and deprecated import statements.

    Each ``(module, item)`` pair is resolved with ``importlib``. A non-empty
    ``item`` is first tried as a submodule (``module.item``); if no such module
    exists it is looked up as an attribute of ``module``. Warnings emitted
    while resolving are recorded so that deprecated imports can be reported
    separately from broken ones.

    Args:
        import_statements: ``(module, item)`` pairs. A falsy ``item`` means a
            plain ``import module``.

    Returns:
        A ``(offending_imports, deprecated_imports)`` tuple. The first list
        holds the pairs that could not be resolved. The second holds the pairs
        whose resolution emitted a ``LangChainDeprecationWarning``; each pair
        is listed once no matter how many such warnings were emitted.
    """
    offending_imports = []
    deprecated_imports = []

    for module, item in import_statements:
        try:
            with warnings.catch_warnings(record=True) as caught_warnings:
                # Record every warning, even ones an active filter would hide.
                warnings.simplefilter("always")

                if item:
                    try:
                        # submodule
                        importlib.import_module(f"{module}.{item}")
                    except ModuleNotFoundError:
                        # attribute
                        try:
                            imported_module = importlib.import_module(module)
                            getattr(imported_module, item)
                        except AttributeError:
                            offending_imports.append((module, item))
                    except Exception:
                        # The submodule exists but failed while importing.
                        offending_imports.append((module, item))
                else:
                    importlib.import_module(module)

                # Check for deprecation warnings. any() records the import once
                # even when resolving it emitted several warnings.
                if any(
                    issubclass(warning.category, LangChainDeprecationWarning)
                    for warning in caught_warnings
                ):
                    deprecated_imports.append((module, item))
        except Exception:
            # Anything not handled above (e.g. the parent module itself is
            # missing) also counts as a broken import.
            offending_imports.append((module, item))

    return offending_imports, deprecated_imports
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_relevant_import(module: str) -> bool:
|
|
|
|
@ -104,6 +119,7 @@ def _serialize_bad_imports(bad_files: list) -> str:
|
|
|
|
|
def check_notebooks(directory: str) -> list:
|
|
|
|
|
"""Check notebooks for broken import statements."""
|
|
|
|
|
bad_files = []
|
|
|
|
|
deprecated_files = []
|
|
|
|
|
for root, _, files in os.walk(directory):
|
|
|
|
|
for file in files:
|
|
|
|
|
if file.endswith(".ipynb") and not file.endswith("-checkpoint.ipynb"):
|
|
|
|
@ -113,18 +129,27 @@ def check_notebooks(directory: str) -> list:
|
|
|
|
|
for module, item in _extract_import_statements(notebook_path)
|
|
|
|
|
if _is_relevant_import(module)
|
|
|
|
|
]
|
|
|
|
|
bad_imports = _get_bad_imports(import_statements)
|
|
|
|
|
bad_imports, deprecated_imports = _get_bad_imports(import_statements)
|
|
|
|
|
if bad_imports:
|
|
|
|
|
bad_files.append(
|
|
|
|
|
(
|
|
|
|
|
os.path.join(root, file),
|
|
|
|
|
os.path.join(root, file).split("docs/")[-1],
|
|
|
|
|
bad_imports,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return bad_files
|
|
|
|
|
if deprecated_imports:
|
|
|
|
|
deprecated_files.append(
|
|
|
|
|
(
|
|
|
|
|
os.path.join(root, file).split("docs/")[-1],
|
|
|
|
|
deprecated_imports,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return bad_files, deprecated_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Scan the docs tree once; check_notebooks is unpacked into the notebooks
    # with broken imports and the notebooks with deprecated imports.
    bad_files, deprecated_files = check_notebooks(DOCS_DIR)
    if deprecated_files:
        # Deprecated imports are reported but do not fail the run.
        logger.warning(
            "Found files with deprecated imports:\n%s",
            _serialize_bad_imports(deprecated_files),
        )
    if bad_files:
        # Broken imports abort the script with an error.
        raise ImportError("Found bad imports:\n" f"{_serialize_bad_imports(bad_files)}")
|
|
|
|
|