"""Tests for :mod:`autoapi.mappers.astroid_utils`."""

import astroid
import pytest

from autoapi.mappers import astroid_utils


# Source snippet handed to ``astroid.extract_node``: the first placeholder is
# an import statement, the second is the base-class expression. The ``#@``
# marker selects the class definition as the node to return. Lines start at
# column 0 so the snippet parses as top-level code.
_CLASS_WITH_BASE_TEMPLATE = '''
{}
class ThisClass({}): #@
    pass
'''


def generate_module_names():
    """Yield dotted module names used to build the import test cases.

    Produces ``module0``, ``module0.module1``, ... (one to four components),
    followed by ``package.repeat.repeat``, whose last two components are
    identical.
    """
    for depth in range(1, 5):
        yield '.'.join('module{}'.format(j) for j in range(depth))

    # NOTE(review): presumably exercises resolution when a component name
    # repeats; confirm intent.
    yield 'package.repeat.repeat'


def imported_basename_cases():
    """Yield ``(import_, basename, expected)`` triples for each import style.

    For every name from :func:`generate_module_names`:

    * ``import_`` is an import statement that brings ``ImportedClass`` (or a
      module containing it) into scope,
    * ``basename`` is how the class is then spelled as a base class,
    * ``expected`` is always ``<module_name>.ImportedClass``.
    """
    for module_name in generate_module_names():
        expected = '{}.ImportedClass'.format(module_name)

        # ``import pkg.mod`` and ``import pkg.mod as aliased``.
        yield ('import {}'.format(module_name), expected, expected)
        yield (
            'import {} as aliased'.format(module_name),
            'aliased.ImportedClass',
            expected,
        )

        # ``from pkg import mod`` (optionally aliased) only makes sense when
        # the module name has a parent package to import from.
        if '.' in module_name:
            from_name, attribute = module_name.rsplit('.', 1)
            from_import = 'from {} import {}'.format(from_name, attribute)
            yield (
                from_import,
                '{}.ImportedClass'.format(attribute),
                expected,
            )
            yield (
                from_import + ' as aliased',
                'aliased.ImportedClass',
                expected,
            )

        # Importing the class itself, plain and aliased.
        yield (
            'from {} import ImportedClass'.format(module_name),
            'ImportedClass',
            expected,
        )
        yield (
            'from {} import ImportedClass as AliasedClass'.format(module_name),
            'AliasedClass',
            expected,
        )


def generate_args():
    """Yield argument-list strings with zero to four arguments.

    Produces ``''``, ``'arg0'``, ``'arg0, arg1'``, and so on.
    """
    for count in range(5):
        yield ', '.join('arg{}'.format(j) for j in range(count))


def imported_call_cases():
    """Yield import cases where the base class is written as a call.

    Every case from :func:`imported_basename_cases` is repeated once per
    argument list from :func:`generate_args`. The base expression becomes
    ``basename(<args>)``, and the expected result is the fully qualified name
    followed by an empty ``()``, whatever arguments were passed.
    """
    for args in generate_args():
        for import_, basename, expected in imported_basename_cases():
            yield (
                import_,
                '{}({})'.format(basename, args),
                expected + '()',
            )


def _full_basename_of(import_, basename):
    """Return ``get_full_basename`` for ``class ThisClass(<basename>)``.

    ``import_`` is placed on the line before the class definition, so the
    names used in ``basename`` are bound by it.
    """
    source = _CLASS_WITH_BASE_TEMPLATE.format(import_, basename)
    node = astroid.extract_node(source)
    return astroid_utils.get_full_basename(node.bases[0], node.basenames[0])


class TestAstroidUtils(object):
    @pytest.mark.parametrize(
        ('import_', 'basename', 'expected'),
        list(imported_basename_cases()),
    )
    def test_can_get_full_imported_basename(self, import_, basename, expected):
        assert _full_basename_of(import_, basename) == expected

    @pytest.mark.parametrize(
        ('import_', 'basename', 'expected'),
        list(imported_call_cases()),
    )
    def test_can_get_full_function_basename(self, import_, basename, expected):
        assert _full_basename_of(import_, basename) == expected

    @pytest.mark.parametrize(('source', 'expected'), [
        ('a = "a"', ('a', 'a')),
        ('a = 1', ('a', 1)),
        # Tuple unpacking and chained assignment yield no single value.
        ('a, b, c = (1, 2, 3)', None),
        ('a = b = 1', None),
    ])
    def test_can_get_assign_values(self, source, expected):
        node = astroid.extract_node(source)
        value = astroid_utils.get_assign_value(node)
        assert value == expected