sphinx-autoapi/tests/test_astroid_utils.py
2018-05-08 22:41:35 -07:00

99 lines
3.1 KiB
Python

import astroid
from autoapi.mappers import astroid_utils
import pytest
def generate_module_names():
    """Yield dotted module paths of increasing depth, then a repeated-name path.

    Produces ``module0``, ``module0.module1``, ... up to four components,
    followed by ``package.repeat.repeat`` (a path whose last two components
    share the same name).
    """
    for depth in range(1, 5):
        parts = ['module{}'.format(index) for index in range(depth)]
        yield '.'.join(parts)
    yield 'package.repeat.repeat'
def imported_basename_cases():
    """Yield ``(import_statement, basename, expected)`` triples.

    For each module path produced by ``generate_module_names()``, every
    exercised way of importing ``ImportedClass`` is paired with the name the
    class is referred to by after that import. ``expected`` is always the
    fully qualified ``<module path>.ImportedClass``.
    """
    for module_name in generate_module_names():
        expected = '{}.ImportedClass'.format(module_name)

        # Import the whole module, plain and aliased.
        module_import = 'import {}'.format(module_name)
        yield (module_import, expected, expected)
        yield (module_import + ' as aliased', 'aliased.ImportedClass', expected)

        # Import the last path component from its parent (dotted paths only).
        if '.' in module_name:
            parent, child = module_name.rsplit('.', 1)
            child_import = 'from {} import {}'.format(parent, child)
            yield (child_import, '{}.ImportedClass'.format(child), expected)
            yield (child_import + ' as aliased', 'aliased.ImportedClass',
                   expected)

        # Import the class itself, plain and aliased.
        class_import = 'from {} import ImportedClass'.format(module_name)
        yield (class_import, 'ImportedClass', expected)
        yield (class_import + ' as AliasedClass', 'AliasedClass', expected)
def generate_args():
    """Yield comma-separated argument lists holding zero to four arguments.

    The sequence is ``''``, ``'arg0'``, ``'arg0, arg1'``, ... up to
    ``'arg0, arg1, arg2, arg3'``.
    """
    names = []
    for count in range(5):
        yield ', '.join(names)
        names.append('arg{}'.format(count))
def imported_call_cases():
    """Yield basename cases in which the base class is a call expression.

    Every argument list from ``generate_args()`` is combined with every case
    from ``imported_basename_cases()``. The basename gets the call appended,
    and the expected full name always gains an empty ``()`` suffix,
    whatever arguments were passed.
    """
    for args in generate_args():
        call_suffix = '({})'.format(args)
        for import_, basename, expected in imported_basename_cases():
            yield (import_, basename + call_suffix, expected + '()')
class TestAstroidUtils(object):
    """Tests for helpers in ``autoapi.mappers.astroid_utils``."""

    @staticmethod
    def _get_full_basename_of_class(import_, basename):
        """Parse a one-class snippet and resolve the class's first base.

        :param import_: An import statement placed before the class.
        :param basename: Source text of the class's single base expression.
        :returns: The result of ``astroid_utils.get_full_basename`` for
            that base.
        """
        # Build the snippet line by line so the class body's indentation is
        # explicit and cannot be lost if this file is reformatted.
        source = '\n'.join([
            import_,
            'class ThisClass({}): #@'.format(basename),
            '    pass',
        ])
        # ``#@`` marks the statement extract_node returns: the ClassDef.
        node = astroid.extract_node(source)
        return astroid_utils.get_full_basename(
            node.bases[0], node.basenames[0],
        )

    @pytest.mark.parametrize(
        ('import_', 'basename', 'expected'), list(imported_basename_cases()),
    )
    def test_can_get_full_imported_basename(self, import_, basename, expected):
        """An imported, possibly aliased, base resolves to its full path."""
        assert self._get_full_basename_of_class(import_, basename) == expected

    @pytest.mark.parametrize(
        ('import_', 'basename', 'expected'), list(imported_call_cases()),
    )
    def test_can_get_full_function_basename(self, import_, basename, expected):
        """A base produced by calling an imported name resolves to ``path()``."""
        assert self._get_full_basename_of_class(import_, basename) == expected

    @pytest.mark.parametrize(('source', 'expected'), [
        ('a = "a"', ('a', 'a')),
        ('a = 1', ('a', 1)),
        ('a, b, c = (1, 2, 3)', None),
        ('a = b = 1', None),
    ])
    def test_can_get_assign_values(self, source, expected):
        """Single-target assignments give ``(name, value)``.

        Tuple-unpacking and chained assignments give ``None``.
        """
        node = astroid.extract_node(source)
        value = astroid_utils.get_assign_value(node)
        assert value == expected