From b079c26f0af8085bccdadc72c61c8164ca5ab0f8 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Sun, 30 Apr 2023 19:50:22 +0200 Subject: [PATCH] [utils] `traverse_obj`: More fixes (#6959) - Fix result when branching with `traverse_string` - Fix `slice` path on `dict`s - Fix tests and docstrings from 21b5ec86c2c37d10c5bb97edd7051d3aac16bb3e - Add `is_iterable_like` helper function Authored by: Grub4K --- test/test_utils.py | 21 +++++++++++++++++++-- yt_dlp/utils.py | 28 ++++++++++++++++++---------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index f2f3b8170..e1bf6ac20 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2016,7 +2016,7 @@ Line 1 msg='nested `...` queries should work') self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), msg='`...` query result should be flattened') - self.assertEqual(traverse_obj(range(4), ...), list(range(4)), + self.assertEqual(traverse_obj(iter(range(4)), ...), list(range(4)), msg='`...` should accept iterables') # Test function as key @@ -2025,7 +2025,7 @@ Line 1 msg='function as query key should perform a filter based on (key, value)') self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, msg='exceptions in the query function should be catched') - self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2], + self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], msg='function key should accept iterables') if __debug__: with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): @@ -2051,6 +2051,17 @@ Line 1 with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): traverse_obj(_TEST_DATA, {str.upper, str}) + # Test `slice` as a key + _SLICE_DATA = [0, 1, 2, 3, 4] + self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None, + msg='slice on a dictionary should not throw') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1], + msg='slice key should apply slice to sequence') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2], + msg='slice key should apply slice to sequence') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2], + msg='slice key should apply slice to sequence') + # Test alternative paths self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', msg='multiple `paths` should be treated as alternative paths') @@ -2234,6 +2245,12 @@ Line 1 self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), traverse_string=True), ['s', 'r'], msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, ...), traverse_string=True), [], + msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, lambda x, y: True), traverse_string=True), [], + msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [], + msg='branching should result in list if `traverse_string`') # Test is_user_input behavior _IS_USER_INPUT_DATA = {'range8': list(range(8))} diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index f69311462..2f5e66720 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -3273,8 +3273,14 @@ def multipart_encode(data, boundary=None): return out, content_type -def variadic(x, allowed_types=(str, bytes, dict)): - return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,) +def is_iterable_like(x, allowed_types=collections.abc.Iterable, blocked_types=NO_DEFAULT): + if blocked_types is NO_DEFAULT: + blocked_types = (str, bytes, collections.abc.Mapping) + return isinstance(x, allowed_types) and not isinstance(x, blocked_types) + + +def variadic(x, allowed_types=NO_DEFAULT): + return x if is_iterable_like(x, blocked_types=allowed_types) else (x,) def dict_get(d, key_or_keys, default=None, skip_false_values=True): @@ -5467,7 +5473,7 @@ def traverse_obj( obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, casesense=True, is_user_input=False, traverse_string=False): """ - Safely traverse nested `dict`s and `Sequence`s + Safely traverse nested `dict`s and `Iterable`s >>> obj = [{}, {"key": "value"}] >>> traverse_obj(obj, (1, "key")) @@ -5475,7 +5481,7 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. + Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -5492,7 +5498,7 @@ def traverse_obj( Read as: `[traverse_obj(obj, branch) for branch in branches]`. - `function`: Branch out and return values filtered by the function. Read as: `[value for key, value in obj if function(key, value)]`. - For `Sequence`s, `key` is the index of the value. + For `Iterable`s, `key` is the index of the value. For `re.Match`es, `key` is the group number (0 = full match) as well as additionally any group names, if given. - `dict` Transform the current object and return a matching dict. @@ -5540,7 +5546,9 @@ def traverse_obj( result = None if obj is None and traverse_string: - pass + if key is ... or callable(key) or isinstance(key, slice): + branching = True + result = () elif key is None: result = obj @@ -5563,7 +5571,7 @@ def traverse_obj( branching = True if isinstance(obj, collections.abc.Mapping): result = obj.values() - elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): + elif is_iterable_like(obj): result = obj elif isinstance(obj, re.Match): result = obj.groups() @@ -5577,7 +5585,7 @@ def traverse_obj( branching = True if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() - elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): + elif is_iterable_like(obj): iter_obj = enumerate(obj) elif isinstance(obj, re.Match): iter_obj = itertools.chain( @@ -5601,7 +5609,7 @@ def traverse_obj( } or None elif isinstance(obj, collections.abc.Mapping): - result = (obj.get(key) if casesense or (key in obj) else + result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else next((v for k, v in obj.items() if casefold(k) == key), None)) elif isinstance(obj, re.Match): @@ -5613,7 +5621,7 @@ def traverse_obj( result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, (int, slice)): - if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)): + if is_iterable_like(obj, collections.abc.Sequence): branching = isinstance(key, slice) with contextlib.suppress(IndexError): result = obj[key]