diff --git a/autoapi/extension.py b/autoapi/extension.py index 45dca8d..cbc722a 100644 --- a/autoapi/extension.py +++ b/autoapi/extension.py @@ -172,9 +172,10 @@ def viewcode_find(app, modname): elif obj.type in ('function', 'method'): type_ = 'def' full_name = prefix + obj.name - locations[full_name] = ( - type_, obj.obj['from_line_no'], obj.obj['to_line_no'], - ) + if 'from_line_no' in obj.obj: + locations[full_name] = ( + type_, obj.obj['from_line_no'], obj.obj['to_line_no'], + ) children = getattr(obj, 'children', ()) stack.extend((full_name + '.', gchild) for gchild in children) diff --git a/autoapi/mappers/astroid_utils.py b/autoapi/mappers/astroid_utils.py index 07df51d..afefe98 100644 --- a/autoapi/mappers/astroid_utils.py +++ b/autoapi/mappers/astroid_utils.py @@ -11,8 +11,10 @@ import astroid.nodes if sys.version_info < (3,): _EXCEPTIONS_MODULE = "exceptions" + _STRING_TYPES = basestring # pylint: disable=undefined-variable else: _EXCEPTIONS_MODULE = "builtins" + _STRING_TYPES = str def resolve_import_alias(name, import_names): @@ -40,7 +42,7 @@ def resolve_import_alias(name, import_names): def get_full_import_name(import_from, name): - """Get the full path of a name from an ``import x from y`` statement. + """Get the full path of a name from a ``from x import y`` statement. :param import_from: The astroid node to resolve the name of. :type import_from: astroid.nodes.ImportFrom @@ -264,3 +266,57 @@ def is_exception(node): return any( is_exception(parent) for parent in node.ancestors(recurs=True) ) + + +def is_local_import_from(node, package_name): + """Check if a node is an import from the local package. + + :param node: The node to check. + :type node: astroid.node.NodeNG + + :param package_name: The name of the local package. + :type package_name: str + + :returns: True if the node is an import from the local package, + False otherwise. + :rtype: bool + """ + if not isinstance(node, astroid.ImportFrom): + return False + + return ( + node.level + or node.modname == package_name + or node.modname.startswith(package_name + '.') + ) + + +def get_module_all(node): + """Get the contents of the ``__all__`` variable from a module. + + :param node: The module to get ``__all__`` from. + :type node: astroid.nodes.Module + + :returns: The contents of ``__all__`` if defined. Otherwise None. + :rtype: list(str) or None + """ + all_ = None + + if '__all__' in node.locals: + assigned = next(node.igetattr('__all__')) + if assigned is not astroid.Uninferable: + all_ = [] + for elt in getattr(assigned, 'elts', ()): + try: + elt_name = next(elt.infer()) + except astroid.InferenceError: + continue + + if elt_name is astroid.Uninferable: + continue + + if (isinstance(elt_name, astroid.Const) + and isinstance(elt_name.value, _STRING_TYPES)): + all_.append(elt_name.value) + + return all_ diff --git a/autoapi/mappers/base.py b/autoapi/mappers/base.py index 2c317ee..1f32790 100644 --- a/autoapi/mappers/base.py +++ b/autoapi/mappers/base.py @@ -267,8 +267,11 @@ class SphinxMapperBase(object): ''' self.objects[obj.id] = obj self.all_objects[obj.id] = obj - for child in obj.children: + child_stack = list(obj.children) + while child_stack: + child = child_stack.pop() self.all_objects[child.id] = child + child_stack.extend(getattr(child, 'children', ())) def map(self, options=None): '''Trigger find of serialized sources and build objects''' diff --git a/autoapi/mappers/python.py b/autoapi/mappers/python.py index 3f0eae7..11aa8a4 100644 --- a/autoapi/mappers/python.py +++ b/autoapi/mappers/python.py @@ -1,6 +1,6 @@ -import sys -import os import collections +import copy +import os import astroid import sphinx @@ -49,7 +49,77 @@ class PythonSphinxMapper(SphinxMapperBase): self.app.warn('Error reading file: {0}'.format(path)) return None + def _resolve_placeholders(self): + """Resolve objects that have been imported from elsewhere.""" + placeholders = [] + all_data = {} + child_stack = [] + # Initialise the stack with module level objects + for data in self.paths.values(): + all_data[data['name']] = data + + for child in data['children']: + child_stack.append((data, data['name'], child)) + + # Find all placeholders and everything that can be resolved to + while child_stack: + parent, parent_name, data = child_stack.pop() + if data['type'] == 'placeholder': + placeholders.append((parent, data)) + + full_name = parent_name + '.' + data['name'] + all_data[full_name] = data + + for child in data.get('children', ()): + child_stack.append((data, full_name, child)) + + # Resolve all placeholders + for parent, placeholder in placeholders: + # Check if this was resolved by a previous iteration + if placeholder['type'] != 'placeholder': + continue + + if placeholder['original_path'] not in all_data: + parent['children'].remove(placeholder) + self.app.debug( + 'Could not resolve {0} for {1}.{2}'.format( + placeholder['original_path'], + parent['name'], + placeholder['name'], + ) + ) + continue + + # Find import chains and resolve the placeholders together + visited = {id(placeholder): placeholder} + original = all_data[placeholder['original_path']] + while original['type'] == 'placeholder': + if id(original) in visited: + parent['children'].remove(placeholder) + break + original = all_data[placeholder['original_path']] + visited[id(original)] = original + else: + if original['type'] in ('package', 'module'): + parent['children'].remove(placeholder) + continue + + for to_resolve in visited.values(): + new = copy.deepcopy(original) + new['name'] = to_resolve['name'] + new['imported'] = True + stack = [new] + while stack: + data = stack.pop() + del data['from_line_no'] + del data['to_line_no'] + stack.extend(data.get('children', ())) + to_resolve.clear() + to_resolve.update(new) + def map(self, options=None): + self._resolve_placeholders() + super(PythonSphinxMapper, self).map(options) parents = {obj.name: obj for obj in self.objects.values()} @@ -234,11 +304,11 @@ class TopLevelPythonPythonMapper(PythonPythonMapper): def __init__(self, obj, **kwargs): super(TopLevelPythonPythonMapper, self).__init__(obj, **kwargs) - self._resolve_name() self.top_level_object = '.' not in self.name self.subpackages = [] self.submodules = [] + self.all = obj['all'] @property def functions(self): @@ -252,30 +322,10 @@ class TopLevelPythonPythonMapper(PythonPythonMapper): class PythonModule(TopLevelPythonPythonMapper): type = 'module' - def _resolve_name(self): - name = self.obj['relative_path'] - name = name.replace(os.sep, '.') - ext = '.py' - if name.endswith(ext): - name = name[:-len(ext)] - - self.name = name - class PythonPackage(TopLevelPythonPythonMapper): type = 'package' - def _resolve_name(self): - name = self.obj['relative_path'] - - exts = [os.sep + '__init__.py', '.py'] - for ext in exts: - if name.endswith(ext): - name = name[:-len(ext)] - name = name.replace(os.sep, '.') - - self.name = name - class PythonClass(PythonPythonMapper): type = 'class' @@ -357,8 +407,11 @@ class PythonException(PythonClass): class Parser(object): def parse_file(self, file_path): directory, filename = os.path.split(file_path) - module_part = os.path.splitext(filename)[0] - module_parts = collections.deque([module_part]) + module_parts = [] + if filename != '__init__.py': + module_part = os.path.splitext(filename)[0] + module_parts = [module_part] + module_parts = collections.deque(module_parts) while os.path.isfile(os.path.join(directory, '__init__.py')): directory, module_part = os.path.split(directory) if module_part: @@ -471,13 +524,28 @@ class Parser(object): return result + def _parse_local_import_from(self, node): + result = [] + + for name, alias in node.names: + full_name = astroid_utils.get_full_import_name(node, alias or name) + + data = { + 'type': 'placeholder', + 'name': alias or name, + 'original_path': full_name, + } + result.append(data) + + return result + def parse_module(self, node): path = node.path if isinstance(node.path, list): path = node.path[0] if node.path else None type_ = 'module' - if path.endswith('__init__.py'): + if node.package: type_ = 'package' data = { @@ -486,10 +554,16 @@ class Parser(object): 'doc': node.doc or '', 'children': [], 'file_path': path, + 'all': astroid_utils.get_module_all(node), } + top_name = node.name.split('.', 1)[0] for child in node.get_children(): - child_data = self.parse(child) + if node.package and astroid_utils.is_local_import_from(child, top_name): + child_data = self._parse_local_import_from(child) + else: + child_data = self.parse(child) + if child_data: data['children'].extend(child_data) diff --git a/autoapi/templates/python/module.rst b/autoapi/templates/python/module.rst index eeb1c3b..bb6cc71 100644 --- a/autoapi/templates/python/module.rst +++ b/autoapi/templates/python/module.rst @@ -79,7 +79,9 @@ Functions {%- for obj_item in obj.children %} +{% if obj.all is none or obj_item.short_name in obj.all %} {{ obj_item.rendered|indent(0) }} +{% endif %} {% endfor %} {% endif %}{% endblock %}