diff --git a/catcli/noder.py b/catcli/noder.py index d0ab9d4..cfd21da 100644 --- a/catcli/noder.py +++ b/catcli/noder.py @@ -14,7 +14,8 @@ import anytree # type: ignore # local imports from catcli import nodes from catcli.nodes import NodeAny, NodeStorage, \ - NodeTop, NodeFile, NodeArchived, NodeDir, NodeMeta + NodeTop, NodeFile, NodeArchived, NodeDir, NodeMeta, \ + typcast_node from catcli.utils import size_to_str, epoch_to_str, md5sum, fix_badchars from catcli.logger import Logger from catcli.nodeprinter import NodePrinter @@ -86,6 +87,7 @@ class Noder: try: bpath = os.path.basename(path) the_node = resolv.get(top, bpath) + typcast_node(the_node) return cast(NodeAny, the_node) except anytree.resolver.ChildResolverError: if not quiet: @@ -296,6 +298,7 @@ class Noder: """remove any node not flagged and clean flags""" cnt = 0 for node in anytree.PreOrderIter(top): + typcast_node(node) if node.type not in [nodes.TYPE_DIR, nodes.TYPE_FILE]: continue if self._clean(node): diff --git a/catcli/nodes.py b/catcli/nodes.py index eb42e81..b622410 100644 --- a/catcli/nodes.py +++ b/catcli/nodes.py @@ -21,6 +21,22 @@ NAME_TOP = 'top' NAME_META = 'meta' +def typcast_node(node: Any) -> None: + """typecast node to its sub type""" + if node.type == TYPE_TOP: + node.__class__ = NodeTop + elif node.type == TYPE_FILE: + node.__class__ = NodeFile + elif node.type == TYPE_DIR: + node.__class__ = NodeDir + elif node.type == TYPE_ARCHIVED: + node.__class__ = NodeArchived + elif node.type == TYPE_STORAGE: + node.__class__ = NodeStorage + elif node.type == TYPE_META: + node.__class__ = NodeMeta + + class NodeAny(NodeMixin): # type: ignore """generic node""" diff --git a/tests/test_update.py b/tests/test_update.py index 1107b45..6f75610 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -123,7 +123,7 @@ class TestUpdate(unittest.TestCase): noder.print_tree(top) # explore the top node to find all nodes - self.assertTrue(len(top.children) == 1) + self.assertEqual(len(top.children), 1) storage = top.children[0] self.assertTrue(len(storage.children) == 8)