diff --git a/tools/nv-driver-locator/get_deb_drivers.py b/tools/nv-driver-locator/get_deb_drivers.py new file mode 100755 index 0000000..9a1f80f --- /dev/null +++ b/tools/nv-driver-locator/get_deb_drivers.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +import urllib.request +import urllib.error +import json +import posixpath +import codecs +import gzip +import sys +from contextlib import contextmanager +import itertools +import string +import codecs +import pprint +import re +import collections + +USER_AGENT = 'Debian APT-HTTP/1.3 (1.6.6)' +DEFAULT_REPO = "https://developer.download.nvidia.com/"\ + "compute/cuda/repos/ubuntu1804/x86_64/Packages.gz" +DEFAULT_REGEX = '^cuda-drivers$' +ENCODING = 'utf-8-sig' + +DriverEntry = collections.namedtuple('DriverEntry', ('name', 'version')) + +def upstream_version(version): + epoch, delim, tail = version.partition(':') + version = tail if delim else epoch + return version.partition('-')[0] + +@contextmanager +def packages_reader(url, timeout): + http_req = urllib.request.Request( + url, + data=None, + headers={ + 'User-Agent': USER_AGENT + } + ) + with urllib.request.urlopen(http_req, None, timeout) as resp: + if url.endswith('.gz') or resp.headers.get('Content-Type', '') == 'application/x-gzip': + with gzip.GzipFile(fileobj=resp) as reader: + yield codecs.getreader(ENCODING)(reader) + else: + yield codecs.getreader(ENCODING)(resp) + +def parse_packages(reader): + for k, g in itertools.groupby(reader, lambda s: bool(s.strip())): + if not k: + continue + pkg = dict() + current_key = None + current_val = None + for line in g: + if line[0] in string.whitespace: + # Continuation + if current_key is None: + continue + current_val += line.rstrip() + else: + # New field + if current_key is not None: + pkg[current_key.lower()] = current_val + current_key, _, current_val = line.partition(':') + current_val = current_val.strip() + if current_key is not None: + pkg[current_key.lower()] = current_val + if 'package' in pkg and 'version' in pkg: + yield pkg + +def _get_deb_versions(*, url=DEFAULT_REPO, name=DEFAULT_REGEX, timeout=10): + pattern = re.compile(name) + with packages_reader(url, timeout) as lines: + for pkg in parse_packages(lines): + if pattern.match(pkg['package']) is not None: + yield DriverEntry(pkg['package'], upstream_version(pkg['version'])) + +def get_deb_versions(*args, **kwargs): + return list(set(_get_deb_versions(*args, **kwargs))) + + +def parse_args(): + import argparse + + def check_positive_float(val): + val = float(val) + if val <= 0: + raise ValueError("Value %s is not valid positive float" % + (repr(val),)) + return val + + parser = argparse.ArgumentParser( + description="Retrieves info about latest NVIDIA drivers available in " + "Nvidia deb packages repositories", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-U", "--url", + default=DEFAULT_REPO, + help="URL for Packages or Packages.gz file") + parser.add_argument("-N", "--name", + default=DEFAULT_REGEX, + help="Package name regexp") + parser.add_argument("-T", "--timeout", + type=check_positive_float, + default=10., + help="timeout for network operations") + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + drv = get_deb_versions(url=args.url, + name=args.name, + timeout=args.timeout) + if drv is None: + print("NOT FOUND") + sys.exit(3) + if False: #not args.raw: + print("Version: %s" % (drv['DriverAttributes']['Version'],)) + print("Beta: %s" % (bool(int(drv['DriverAttributes']['IsBeta'])),)) + print("WHQL: %s" % (bool(int(drv['DriverAttributes']['IsWHQL'])),)) + print("URL: %s" % (drv['DriverAttributes']['DownloadURLAdmin'],)) + else: + json.dump(drv, sys.stdout, indent=4) + sys.stdout.flush() + + +if __name__ == '__main__': + main()