sphinx-autoapi/tests/test_astroid_utils.py
2019-01-26 21:20:45 -08:00

102 lines
3.1 KiB
Python

import astroid
from autoapi.mappers import astroid_utils
import pytest
def generate_module_names():
    """Yield the dotted module names used to build import test cases.

    Produces ``module0``, ``module0.module1``, ... up to a name with four
    components, followed by ``package.repeat.repeat``.
    """
    for i in range(1, 5):
        yield ".".join("module{}".format(j) for j in range(i))

    # A name whose trailing components are identical.
    # NOTE(review): presumably here to exercise repeated path segments
    # during name resolution -- confirm.
    yield "package.repeat.repeat"
def imported_basename_cases():
    """Yield ``(import_, basename, expected)`` cases for base class lookups.

    For every name from ``generate_module_names()`` this produces one case
    per way of importing ``ImportedClass`` from that module:

    * ``import <module>`` and ``import <module> as aliased``
    * for dotted names only, ``from <parent> import <leaf>`` with and
      without ``as aliased``
    * ``from <module> import ImportedClass`` with and without
      ``as AliasedClass``

    ``import_`` is the import statement, ``basename`` is how the class is
    then spelled in a ``class`` statement's bases, and ``expected`` is always
    ``<module>.ImportedClass``.
    """
    for module_name in generate_module_names():
        import_ = "import {}".format(module_name)
        basename = "{}.ImportedClass".format(module_name)
        # The plain-import spelling is the fully qualified name, so it is
        # the expected result for every other case of this module.
        expected = basename
        yield (import_, basename, expected)

        import_ = "import {} as aliased".format(module_name)
        basename = "aliased.ImportedClass"
        yield (import_, basename, expected)

        # Only a dotted name can be split into a package and a submodule.
        if "." in module_name:
            from_name, attribute = module_name.rsplit(".", 1)
            import_ = "from {} import {}".format(from_name, attribute)
            basename = "{}.ImportedClass".format(attribute)
            yield (import_, basename, expected)

            import_ += " as aliased"
            basename = "aliased.ImportedClass"
            yield (import_, basename, expected)

        import_ = "from {} import ImportedClass".format(module_name)
        basename = "ImportedClass"
        yield (import_, basename, expected)

        import_ = "from {} import ImportedClass as AliasedClass".format(module_name)
        basename = "AliasedClass"
        yield (import_, basename, expected)
def generate_args():
    """Yield comma-separated argument lists with zero to four arguments.

    Produces ``""``, ``"arg0"``, ``"arg0, arg1"`` and so on, for use as the
    text between the parentheses of a call expression.
    """
    for i in range(5):
        yield ", ".join("arg{}".format(j) for j in range(i))
def imported_call_cases():
    """Yield ``(import_, basename, expected)`` cases where the base is a call.

    Every case from ``imported_basename_cases()`` is repeated once per
    argument list from ``generate_args()``. The base expression becomes
    ``<basename>(<args>)`` while the expected name becomes ``<expected>()``:
    the arguments are not part of the expected result.
    """
    for args in generate_args():
        for import_, basename, expected in imported_basename_cases():
            basename += "({})".format(args)
            expected += "()"
            yield import_, basename, expected
class TestAstroidUtils(object):
    """Tests for the helpers in ``autoapi.mappers.astroid_utils``."""

    @staticmethod
    def _extract_subclass(import_, basename):
        """Return the astroid node of a class that derives from ``basename``.

        The parsed module consists of the ``import_`` statement followed by
        ``class ThisClass(<basename>)``.
        """
        # The trailing ``#@`` marks the class statement as the node that
        # ``extract_node`` should return. The template is indented to match
        # this method; astroid is expected to dedent the code before parsing
        # it. NOTE(review): confirm against the pinned astroid version.
        source = """
        {}
        class ThisClass({}): #@
            pass
        """.format(
            import_, basename
        )
        return astroid.extract_node(source)

    @pytest.mark.parametrize(
        ("import_", "basename", "expected"), list(imported_basename_cases())
    )
    def test_can_get_full_imported_basename(self, import_, basename, expected):
        """The first base of the class resolves to ``expected``."""
        node = self._extract_subclass(import_, basename)
        full_basename = astroid_utils.get_full_basename(
            node.bases[0], node.basenames[0]
        )
        assert full_basename == expected

    @pytest.mark.parametrize(
        ("import_", "basename", "expected"), list(imported_call_cases())
    )
    def test_can_get_full_function_basename(self, import_, basename, expected):
        """A base written as a call expression resolves to ``expected``."""
        node = self._extract_subclass(import_, basename)
        full_basename = astroid_utils.get_full_basename(
            node.bases[0], node.basenames[0]
        )
        assert full_basename == expected

    @pytest.mark.parametrize(
        ("source", "expected"),
        [
            ('a = "a"', ("a", "a")),
            ("a = 1", ("a", 1)),
            # Tuple unpacking and chained assignment are expected to give
            # no value.
            ("a, b, c = (1, 2, 3)", None),
            ("a = b = 1", None),
        ],
    )
    def test_can_get_assign_values(self, source, expected):
        """``get_assign_value`` returns ``expected`` for the parsed assignment."""
        node = astroid.extract_node(source)
        value = astroid_utils.get_assign_value(node)
        assert value == expected