From 72707183744ff9daabffc787f9f1ce094630e547 Mon Sep 17 00:00:00 2001 From: Taku Fukada Date: Thu, 6 Aug 2020 16:38:29 +0900 Subject: [PATCH] Add support for overloaded functions --- autoapi/directives.py | 9 ++- autoapi/mappers/python/astroid_utils.py | 36 +++++++++++ autoapi/mappers/python/objects.py | 8 +++ autoapi/mappers/python/parser.py | 72 +++++++++++++++++++--- autoapi/templates/python/function.rst | 9 ++- autoapi/templates/python/method.rst | 22 +++++-- tests/python/py3example/example/example.py | 63 ++++++++++++++++++- tests/python/test_pyintegration.py | 40 ++++++++++++ 8 files changed, 241 insertions(+), 18 deletions(-) diff --git a/autoapi/directives.py b/autoapi/directives.py index a7f4970..a0a4509 100644 --- a/autoapi/directives.py +++ b/autoapi/directives.py @@ -20,9 +20,12 @@ class AutoapiSummary(Autosummary): # pylint: disable=too-few-public-methods for name in names: obj = mapper.all_objects[name] if isinstance(obj, PythonFunction): - sig = "({})".format(obj.args) - if obj.return_annotation is not None: - sig += " -> {}".format(obj.return_annotation) + if len(obj.signatures) > 1: + sig = "(\u2026)" + else: + sig = "({})".format(obj.args) + if obj.return_annotation is not None: + sig += " \u2192 {}".format(obj.return_annotation) else: sig = "" diff --git a/autoapi/mappers/python/astroid_utils.py b/autoapi/mappers/python/astroid_utils.py index 2391b11..abbb1dd 100644 --- a/autoapi/mappers/python/astroid_utils.py +++ b/autoapi/mappers/python/astroid_utils.py @@ -275,6 +275,42 @@ def is_decorated_with_property_setter(node): return False +def is_decorated_with_overload(node): + """Check if the function is decorated as an overload definition. + + :param node: The node to check. + :type node: astroid.nodes.FunctionDef + + :returns: True if the function is an overload definition, False otherwise. + :rtype: bool + """ + if not node.decorators: + return False + + for decorator in node.decorators.nodes: + if not isinstance(decorator, (astroid.Name, astroid.Attribute)): + continue + + try: + if _is_overload_decorator(decorator): + return True + except astroid.InferenceError: + pass + + return False + + +def _is_overload_decorator(decorator): + for inferred in decorator.infer(): + if not isinstance(inferred, astroid.nodes.FunctionDef): + continue + + if inferred.name == "overload" and inferred.root().name == "typing": + return True + + return False + + def is_constructor(node): """Check if the function is a constructor. diff --git a/autoapi/mappers/python/objects.py b/autoapi/mappers/python/objects.py index 8511172..39014d4 100644 --- a/autoapi/mappers/python/objects.py +++ b/autoapi/mappers/python/objects.py @@ -183,6 +183,14 @@ class PythonFunction(PythonPythonMapper): :type: list(str) """ + self.signatures = obj["signatures"] + """The list of all signatures ``[(args, return_annotation), ...]`` of this function. + + When this function is not overloaded, + it must be the same as ``[(self.args, self.return_annotation)]``. + + :type: list(tuple(str, str)) + """ class PythonMethod(PythonFunction): diff --git a/autoapi/mappers/python/parser.py b/autoapi/mappers/python/parser.py index d62daa8..40ad76f 100644 --- a/autoapi/mappers/python/parser.py +++ b/autoapi/mappers/python/parser.py @@ -127,8 +127,10 @@ class Parser(object): } self._name_stack.append(node.name) - seen = set() + overridden = set() + overloads = {} for base in itertools.chain(iter((node,)), node.ancestors()): + seen = set() if base.qname() in ("__builtins__.object", "builtins.object"): continue for child in base.get_children(): @@ -138,14 +140,42 @@ class Parser(object): if not assign_value: continue name = assign_value[0] - if not name or name in seen: + + if not name or name in overridden: continue seen.add(name) child_data = self.parse(child) - if child_data: - for single_data in child_data: - single_data["inherited"] = base is not node - data["children"].extend(child_data) + + for single_data in child_data: + if single_data["type"] in ("method", "property"): + if name in overloads: + grouped = overloads[name] + if single_data["doc"]: + grouped["doc"] += "\n\n" + single_data["doc"] + if single_data["is_overload"]: + grouped["signatures"].append( + ( + single_data["args"], + single_data["return_annotation"], + ) + ) + else: + grouped["args"] = single_data["args"] + grouped["return_annotation"] = single_data[ + "return_annotation" + ] + continue + if single_data["is_overload"] and name not in overloads: + overloads[name] = single_data + single_data["signatures"] = [ + (single_data["args"], single_data["return_annotation"]) + ] + + single_data["inherited"] = base is not node + data["children"].append(single_data) + + overridden.update(seen) + self._name_stack.pop() return [data] @@ -159,6 +189,7 @@ class Parser(object): type_ = "method" properties = [] + if node.type == "function": type_ = "function" elif astroid_utils.is_decorated_with_property(node): @@ -193,6 +224,7 @@ class Parser(object): "to_line_no": node.tolineno, "return_annotation": return_annotation, "properties": properties, + "is_overload": astroid_utils.is_decorated_with_overload(node), } if type_ in ("method", "property"): @@ -249,6 +281,7 @@ class Parser(object): "all": astroid_utils.get_module_all(node), } + overloads = {} top_name = node.name.split(".", 1)[0] for child in node.get_children(): if astroid_utils.is_local_import_from(child, top_name): @@ -256,8 +289,30 @@ class Parser(object): else: child_data = self.parse(child) - if child_data: - data["children"].extend(child_data) + for single_data in child_data: + if single_data["type"] == "function": + name = single_data["name"] + if name in overloads: + grouped = overloads[name] + if single_data["doc"]: + grouped["doc"] += "\n\n" + single_data["doc"] + if single_data["is_overload"]: + grouped["signatures"].append( + (single_data["args"], single_data["return_annotation"]) + ) + else: + grouped["args"] = single_data["args"] + grouped["return_annotation"] = single_data[ + "return_annotation" + ] + continue + if single_data["is_overload"] and name not in overloads: + overloads[name] = single_data + single_data["signatures"] = [ + (single_data["args"], single_data["return_annotation"]) + ] + + data["children"].append(single_data) return data @@ -273,5 +328,4 @@ class Parser(object): data = self.parse(child) if data: break - return data diff --git a/autoapi/templates/python/function.rst b/autoapi/templates/python/function.rst index 829f498..f8feb2b 100644 --- a/autoapi/templates/python/function.rst +++ b/autoapi/templates/python/function.rst @@ -1,6 +1,13 @@ {% if obj.display %} -.. function:: {{ obj.short_name }}({{ obj.args }}){% if obj.return_annotation is not none %} -> {{ obj.return_annotation }}{% endif %} +{% for (args, return_annotation) in obj.signatures %} +{% if loop.index0 == 0 %} +.. function:: {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %} +{% else %} + {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %} + +{% endif %} +{% endfor %} {% if sphinx_version >= (2, 1) %} {% for property in obj.properties %} :{{ property }}: diff --git a/autoapi/templates/python/method.rst b/autoapi/templates/python/method.rst index d73a3ea..b6d7f99 100644 --- a/autoapi/templates/python/method.rst +++ b/autoapi/templates/python/method.rst @@ -1,19 +1,33 @@ {%- if obj.display %} {% if sphinx_version >= (2, 1) %} -.. method:: {{ obj.short_name }}({{ obj.args }}){% if obj.return_annotation is not none %} -> {{ obj.return_annotation }}{% endif %} - {% if obj.properties %} +{% for (args, return_annotation) in obj.signatures %} +{% if loop.index0 == 0 %} +.. method:: {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %} + +{% else %} + {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %} +{% endif %} +{% endfor %} + {% if obj.properties %} {% for property in obj.properties %} :{{ property }}: {% endfor %} + {% else %} {% endif %} {% else %} -.. {{ obj.method_type }}:: {{ obj.short_name }}({{ obj.args }}) +{% for (args, return_annotation) in obj.signatures %} +{% if loop.index0 == 0 %} +.. {{ obj.method_type }}:: {{ obj.short_name }}({{ args }}) -{% endif %} +{% else %} + :: {{ obj.short_name }}({{ args }}) +{% endif %} +{% endfor %} +{% endif %} {% if obj.docstring %} {{ obj.docstring|prepare_docstring|indent(3) }} {% endif %} diff --git a/tests/python/py3example/example/example.py b/tests/python/py3example/example/example.py index 1b852d2..83816e1 100644 --- a/tests/python/py3example/example/example.py +++ b/tests/python/py3example/example/example.py @@ -4,7 +4,8 @@ This is a description """ import asyncio -from typing import ClassVar, Dict, Iterable, List, Union +import typing +from typing import ClassVar, Dict, Iterable, List, Union, overload max_rating: int = 10 @@ -35,7 +36,33 @@ def f2(not_yet_a: "A") -> int: ... +@overload +def overloaded_func(a: float) -> float: + ... + + +@typing.overload +def overloaded_func(a: str) -> str: + ... + + +def overloaded_func(a: Union[float, str]) -> Union[float, str]: + """Overloaded function""" + return a * 2 + + +@overload +def undoc_overloaded_func(a: str) -> str: + ... + + +def undoc_overloaded_func(a: str) -> str: + return a * 2 + + class A: + """class A""" + is_an_a: ClassVar[bool] = True not_assigned_to: ClassVar[str] @@ -57,6 +84,40 @@ class A: """My method.""" return "method" + @overload + def overloaded_method(self, a: float) -> float: + ... + + @typing.overload + def overloaded_method(self, a: str) -> str: + ... + + def overloaded_method(self, a: Union[float, str]) -> Union[float, str]: + """Overloaded method""" + return a * 2 + + @overload + def undoc_overloaded_method(self, a: float) -> float: + ... + + def undoc_overloaded_method(self, a: float) -> float: + return a * 2 + + @typing.overload + @classmethod + def overloaded_class_method(cls, a: float) -> float: + ... + + @overload + @classmethod + def overloaded_class_method(cls, a: str) -> str: + ... + + @classmethod + def overloaded_class_method(cls, a: Union[float, str]) -> Union[float, str]: + """Overloaded class method""" + return a * 2 + async def async_function(self, wait: bool) -> int: if wait: diff --git a/tests/python/test_pyintegration.py b/tests/python/test_pyintegration.py index 2234c15..25a34b0 100644 --- a/tests/python/test_pyintegration.py +++ b/tests/python/test_pyintegration.py @@ -176,6 +176,29 @@ class TestPy3Module(object): if sphinx.version_info >= (2, 1): assert "my_method(self) -> str" in example_file + def test_overload(self): + example_path = "_build/text/autoapi/example/index.txt" + with io.open(example_path, encoding="utf8") as example_handle: + example_file = example_handle.read() + + assert "overloaded_func(a: float" in example_file + assert "overloaded_func(a: str" in example_file + assert "overloaded_func(a: Union" not in example_file + assert "Overloaded function" in example_file + + assert "overloaded_method(self, a: float" in example_file + assert "overloaded_method(self, a: str" in example_file + assert "overloaded_method(self, a: Union" not in example_file + assert "Overloaded method" in example_file + + assert "overloaded_class_method(cls, a: float" in example_file + assert "overloaded_class_method(cls, a: str" in example_file + assert "overloaded_class_method(cls, a: Union" not in example_file + assert "Overloaded method" in example_file + + assert "undoc_overloaded_func" in example_file + assert "undoc_overloaded_method" in example_file + def test_async(self): example_path = "_build/text/autoapi/example/index.txt" with io.open(example_path, encoding="utf8") as example_handle: @@ -189,6 +212,23 @@ class TestPy3Module(object): assert "async_function" in example_file +@pytest.mark.skipif( + sys.version_info < (3, 6), reason="Annotations are invalid in Python <3.5" +) +def test_py3_hiding_undoc_overloaded_members(builder): + confoverrides = {"autoapi_options": ["members", "special-members"]} + builder("py3example", confoverrides=confoverrides) + + example_path = "_build/text/autoapi/example/index.txt" + with io.open(example_path, encoding="utf8") as example_handle: + example_file = example_handle.read() + + assert "overloaded_func" in example_file + assert "overloaded_method" in example_file + assert "undoc_overloaded_func" not in example_file + assert "undoc_overloaded_method" not in example_file + + @pytest.mark.skipif( sys.version_info < (3,), reason="Annotations are not supported in astroid<2" )