diff --git a/test/test_networking.py b/test/test_networking.py index 9c33b0d4c..2622d24da 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -1035,17 +1035,17 @@ class TestRequestDirector: assert isinstance(director.send(Request('http://')), FakeResponse) def test_unsupported_handlers(self): - director = RequestDirector(logger=FakeLogger()) - director.add_handler(FakeRH(logger=FakeLogger())) - class SupportedRH(RequestHandler): _SUPPORTED_URL_SCHEMES = ['http'] def _send(self, request: Request): return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url) - # This handler should by default take preference over FakeRH + director = RequestDirector(logger=FakeLogger()) director.add_handler(SupportedRH(logger=FakeLogger())) + director.add_handler(FakeRH(logger=FakeLogger())) + + # First should take preference assert director.send(Request('http://')).read() == b'supported' assert director.send(Request('any://')).read() == b'' @@ -1072,6 +1072,27 @@ class TestRequestDirector: director.add_handler(UnexpectedRH(logger=FakeLogger)) assert director.send(Request('any://')) + def test_preference(self): + director = RequestDirector(logger=FakeLogger()) + director.add_handler(FakeRH(logger=FakeLogger())) + + class SomeRH(RequestHandler): + _SUPPORTED_URL_SCHEMES = ['http'] + + def _send(self, request: Request): + return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url) + + def some_preference(rh, request): + return (0 if not isinstance(rh, SomeRH) + else 100 if 'prefer' in request.headers + else -1) + + director.add_handler(SomeRH(logger=FakeLogger())) + director.preferences.add(some_preference) + + assert director.send(Request('http://')).read() == b'' + assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported' + # XXX: do we want to move this to test_YoutubeDL.py? class TestYoutubeDLNetworking: diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 87bca5bbe..666d89b46 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -34,7 +34,7 @@ from .extractor.common import UnsupportedURLIE from .extractor.openload import PhantomJSwrapper from .minicurses import format_text from .networking import HEADRequest, Request, RequestDirector -from .networking.common import _REQUEST_HANDLERS +from .networking.common import _REQUEST_HANDLERS, _RH_PREFERENCES from .networking.exceptions import ( HTTPError, NoSupportingHandlers, @@ -683,7 +683,7 @@ class YoutubeDL: self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers')) self._load_cookies(self.params['http_headers'].get('Cookie')) # compat self.params['http_headers'].pop('Cookie', None) - self._request_director = self.build_request_director(_REQUEST_HANDLERS.values()) + self._request_director = self.build_request_director(_REQUEST_HANDLERS.values(), _RH_PREFERENCES) if auto_init and auto_init != 'no_verbose_header': self.print_debug_header() @@ -4077,7 +4077,7 @@ class YoutubeDL: except HTTPError as e: # TODO: Remove in a future release raise _CompatHTTPError(e) from e - def build_request_director(self, handlers): + def build_request_director(self, handlers, preferences=None): logger = _YDLLogger(self) headers = self.params['http_headers'].copy() proxies = self.proxies.copy() @@ -4106,6 +4106,7 @@ class YoutubeDL: }, }), )) + director.preferences.update(preferences or []) return director def encode(self, s): diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index 8fba8c1c5..584c7bb4d 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -31,8 +31,19 @@ from ..utils import ( ) from ..utils.networking import HTTPHeaderDict, normalize_url -if typing.TYPE_CHECKING: - RequestData = bytes | Iterable[bytes] | typing.IO | None + +def register_preference(*handlers: type[RequestHandler]): + assert all(issubclass(handler, RequestHandler) for handler in handlers) + + def outer(preference: Preference): + @functools.wraps(preference) + def inner(handler, *args, **kwargs): + if not handlers or isinstance(handler, handlers): + return preference(handler, *args, **kwargs) + return 0 + _RH_PREFERENCES.add(inner) + return inner + return outer class RequestDirector: @@ -40,12 +51,17 @@ class RequestDirector: Helper class that, when given a request, forward it to a RequestHandler that supports it. + Preference functions in the form of func(handler, request) -> int + can be registered into the `preferences` set. These are used to sort handlers + in order of preference. + @param logger: Logger instance. @param verbose: Print debug request information to stdout. """ def __init__(self, logger, verbose=False): self.handlers: dict[str, RequestHandler] = {} + self.preferences: set[Preference] = set() self.logger = logger # TODO(Grub4k): default logger self.verbose = verbose @@ -58,6 +74,16 @@ class RequestDirector: assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler' self.handlers[handler.RH_KEY] = handler + def _get_handlers(self, request: Request) -> list[RequestHandler]: + """Sorts handlers by preference, given a request""" + preferences = { + rh: sum(pref(rh, request) for pref in self.preferences) + for rh in self.handlers.values() + } + self._print_verbose('Handler preferences for this request: %s' % ', '.join( + f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items())) + return sorted(self.handlers.values(), key=preferences.get, reverse=True) + def _print_verbose(self, msg): if self.verbose: self.logger.stdout(f'director: {msg}') @@ -73,8 +99,7 @@ class RequestDirector: unexpected_errors = [] unsupported_errors = [] - # TODO (future): add a per-request preference system - for handler in reversed(list(self.handlers.values())): + for handler in self._get_handlers(request): self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.') try: handler.validate(request) @@ -530,3 +555,10 @@ class Response(io.IOBase): def getheader(self, name, default=None): deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2) return self.get_header(name, default) + + +if typing.TYPE_CHECKING: + RequestData = bytes | Iterable[bytes] | typing.IO | None + Preference = typing.Callable[[RequestHandler, Request], int] + +_RH_PREFERENCES: set[Preference] = set()