Add support for overloaded functions

pull/245/head
Taku Fukada 4 years ago committed by Ashley Whetter
parent d0de570660
commit 7270718374

@ -20,9 +20,12 @@ class AutoapiSummary(Autosummary): # pylint: disable=too-few-public-methods
for name in names:
obj = mapper.all_objects[name]
if isinstance(obj, PythonFunction):
sig = "({})".format(obj.args)
if obj.return_annotation is not None:
sig += " -> {}".format(obj.return_annotation)
if len(obj.signatures) > 1:
sig = "(\u2026)"
else:
sig = "({})".format(obj.args)
if obj.return_annotation is not None:
sig += " \u2192 {}".format(obj.return_annotation)
else:
sig = ""

@ -275,6 +275,42 @@ def is_decorated_with_property_setter(node):
return False
def is_decorated_with_overload(node):
"""Check if the function is decorated as an overload definition.
:param node: The node to check.
:type node: astroid.nodes.FunctionDef
:returns: True if the function is an overload definition, False otherwise.
:rtype: bool
"""
if not node.decorators:
return False
for decorator in node.decorators.nodes:
if not isinstance(decorator, (astroid.Name, astroid.Attribute)):
continue
try:
if _is_overload_decorator(decorator):
return True
except astroid.InferenceError:
pass
return False
def _is_overload_decorator(decorator):
for inferred in decorator.infer():
if not isinstance(inferred, astroid.nodes.FunctionDef):
continue
if inferred.name == "overload" and inferred.root().name == "typing":
return True
return False
def is_constructor(node):
"""Check if the function is a constructor.

@ -183,6 +183,14 @@ class PythonFunction(PythonPythonMapper):
:type: list(str)
"""
self.signatures = obj["signatures"]
"""The list of all signatures ``[(args, return_annotation), ...]`` of this function.
When this function is not overloaded,
it must be the same as ``[(self.args, self.return_annotation)]``.
:type: list(tuple(str, str))
"""
class PythonMethod(PythonFunction):

@ -127,8 +127,10 @@ class Parser(object):
}
self._name_stack.append(node.name)
seen = set()
overridden = set()
overloads = {}
for base in itertools.chain(iter((node,)), node.ancestors()):
seen = set()
if base.qname() in ("__builtins__.object", "builtins.object"):
continue
for child in base.get_children():
@ -138,14 +140,42 @@ class Parser(object):
if not assign_value:
continue
name = assign_value[0]
if not name or name in seen:
if not name or name in overridden:
continue
seen.add(name)
child_data = self.parse(child)
if child_data:
for single_data in child_data:
single_data["inherited"] = base is not node
data["children"].extend(child_data)
for single_data in child_data:
if single_data["type"] in ("method", "property"):
if name in overloads:
grouped = overloads[name]
if single_data["doc"]:
grouped["doc"] += "\n\n" + single_data["doc"]
if single_data["is_overload"]:
grouped["signatures"].append(
(
single_data["args"],
single_data["return_annotation"],
)
)
else:
grouped["args"] = single_data["args"]
grouped["return_annotation"] = single_data[
"return_annotation"
]
continue
if single_data["is_overload"] and name not in overloads:
overloads[name] = single_data
single_data["signatures"] = [
(single_data["args"], single_data["return_annotation"])
]
single_data["inherited"] = base is not node
data["children"].append(single_data)
overridden.update(seen)
self._name_stack.pop()
return [data]
@ -159,6 +189,7 @@ class Parser(object):
type_ = "method"
properties = []
if node.type == "function":
type_ = "function"
elif astroid_utils.is_decorated_with_property(node):
@ -193,6 +224,7 @@ class Parser(object):
"to_line_no": node.tolineno,
"return_annotation": return_annotation,
"properties": properties,
"is_overload": astroid_utils.is_decorated_with_overload(node),
}
if type_ in ("method", "property"):
@ -249,6 +281,7 @@ class Parser(object):
"all": astroid_utils.get_module_all(node),
}
overloads = {}
top_name = node.name.split(".", 1)[0]
for child in node.get_children():
if astroid_utils.is_local_import_from(child, top_name):
@ -256,8 +289,30 @@ class Parser(object):
else:
child_data = self.parse(child)
if child_data:
data["children"].extend(child_data)
for single_data in child_data:
if single_data["type"] == "function":
name = single_data["name"]
if name in overloads:
grouped = overloads[name]
if single_data["doc"]:
grouped["doc"] += "\n\n" + single_data["doc"]
if single_data["is_overload"]:
grouped["signatures"].append(
(single_data["args"], single_data["return_annotation"])
)
else:
grouped["args"] = single_data["args"]
grouped["return_annotation"] = single_data[
"return_annotation"
]
continue
if single_data["is_overload"] and name not in overloads:
overloads[name] = single_data
single_data["signatures"] = [
(single_data["args"], single_data["return_annotation"])
]
data["children"].append(single_data)
return data
@ -273,5 +328,4 @@ class Parser(object):
data = self.parse(child)
if data:
break
return data

@ -1,6 +1,13 @@
{% if obj.display %}
.. function:: {{ obj.short_name }}({{ obj.args }}){% if obj.return_annotation is not none %} -> {{ obj.return_annotation }}{% endif %}
{% for (args, return_annotation) in obj.signatures %}
{% if loop.index0 == 0 %}
.. function:: {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %}
{% else %}
{{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %}
{% endif %}
{% endfor %}
{% if sphinx_version >= (2, 1) %}
{% for property in obj.properties %}
:{{ property }}:

@ -1,19 +1,33 @@
{%- if obj.display %}
{% if sphinx_version >= (2, 1) %}
.. method:: {{ obj.short_name }}({{ obj.args }}){% if obj.return_annotation is not none %} -> {{ obj.return_annotation }}{% endif %}
{% if obj.properties %}
{% for (args, return_annotation) in obj.signatures %}
{% if loop.index0 == 0 %}
.. method:: {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %}
{% else %}
{{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %}
{% endif %}
{% endfor %}
{% if obj.properties %}
{% for property in obj.properties %}
:{{ property }}:
{% endfor %}
{% else %}
{% endif %}
{% else %}
.. {{ obj.method_type }}:: {{ obj.short_name }}({{ obj.args }})
{% for (args, return_annotation) in obj.signatures %}
{% if loop.index0 == 0 %}
.. {{ obj.method_type }}:: {{ obj.short_name }}({{ args }})
{% endif %}
{% else %}
:: {{ obj.short_name }}({{ args }})
{% endif %}
{% endfor %}
{% endif %}
{% if obj.docstring %}
{{ obj.docstring|prepare_docstring|indent(3) }}
{% endif %}

@ -4,7 +4,8 @@
This is a description
"""
import asyncio
from typing import ClassVar, Dict, Iterable, List, Union
import typing
from typing import ClassVar, Dict, Iterable, List, Union, overload
max_rating: int = 10
@ -35,7 +36,33 @@ def f2(not_yet_a: "A") -> int:
...
@overload
def overloaded_func(a: float) -> float:
...
@typing.overload
def overloaded_func(a: str) -> str:
...
def overloaded_func(a: Union[float, str]) -> Union[float, str]:
"""Overloaded function"""
return a * 2
@overload
def undoc_overloaded_func(a: str) -> str:
...
def undoc_overloaded_func(a: str) -> str:
return a * 2
class A:
"""class A"""
is_an_a: ClassVar[bool] = True
not_assigned_to: ClassVar[str]
@ -57,6 +84,40 @@ class A:
"""My method."""
return "method"
@overload
def overloaded_method(self, a: float) -> float:
...
@typing.overload
def overloaded_method(self, a: str) -> str:
...
def overloaded_method(self, a: Union[float, str]) -> Union[float, str]:
"""Overloaded method"""
return a * 2
@overload
def undoc_overloaded_method(self, a: float) -> float:
...
def undoc_overloaded_method(self, a: float) -> float:
return a * 2
@typing.overload
@classmethod
def overloaded_class_method(cls, a: float) -> float:
...
@overload
@classmethod
def overloaded_class_method(cls, a: str) -> str:
...
@classmethod
def overloaded_class_method(cls, a: Union[float, str]) -> Union[float, str]:
"""Overloaded class method"""
return a * 2
async def async_function(self, wait: bool) -> int:
if wait:

@ -176,6 +176,29 @@ class TestPy3Module(object):
if sphinx.version_info >= (2, 1):
assert "my_method(self) -> str" in example_file
def test_overload(self):
example_path = "_build/text/autoapi/example/index.txt"
with io.open(example_path, encoding="utf8") as example_handle:
example_file = example_handle.read()
assert "overloaded_func(a: float" in example_file
assert "overloaded_func(a: str" in example_file
assert "overloaded_func(a: Union" not in example_file
assert "Overloaded function" in example_file
assert "overloaded_method(self, a: float" in example_file
assert "overloaded_method(self, a: str" in example_file
assert "overloaded_method(self, a: Union" not in example_file
assert "Overloaded method" in example_file
assert "overloaded_class_method(cls, a: float" in example_file
assert "overloaded_class_method(cls, a: str" in example_file
assert "overloaded_class_method(cls, a: Union" not in example_file
assert "Overloaded method" in example_file
assert "undoc_overloaded_func" in example_file
assert "undoc_overloaded_method" in example_file
def test_async(self):
example_path = "_build/text/autoapi/example/index.txt"
with io.open(example_path, encoding="utf8") as example_handle:
@ -189,6 +212,23 @@ class TestPy3Module(object):
assert "async_function" in example_file
@pytest.mark.skipif(
sys.version_info < (3, 6), reason="Annotations are invalid in Python <3.5"
)
def test_py3_hiding_undoc_overloaded_members(builder):
confoverrides = {"autoapi_options": ["members", "special-members"]}
builder("py3example", confoverrides=confoverrides)
example_path = "_build/text/autoapi/example/index.txt"
with io.open(example_path, encoding="utf8") as example_handle:
example_file = example_handle.read()
assert "overloaded_func" in example_file
assert "overloaded_method" in example_file
assert "undoc_overloaded_func" not in example_file
assert "undoc_overloaded_method" not in example_file
@pytest.mark.skipif(
sys.version_info < (3,), reason="Annotations are not supported in astroid<2"
)

Loading…
Cancel
Save