#!/usr/bin/env python3
"""Poll configured driver-release channels and notify about unseen entries.

A JSON config (default ``/etc/nv-driver-locator.json``) lists ``channels``,
a ``db``, ``key_components`` and ``notifiers``.  Each driver record a channel
yields is hashed from the configured key components; records whose key is
not yet in the DB are sent to every notifier, and the key is stored once at
least one notifier succeeds.  The process exit code is 0, or 3 if any error
was logged.
"""

import sys
import json
import argparse
import hashlib
import importlib
import logging
from abc import ABC, abstractmethod
import functools

# Separator inserted between key components before hashing, so that e.g.
# ("ab", "c") and ("a", "bc") produce different digests.
HASH_DELIM = b'\x00'
HASH = hashlib.sha256


class BaseDB(ABC):
    """Interface of the store that remembers already-notified driver keys."""

    @abstractmethod
    def check_key(self, key):
        """Return a truthy value if *key* has already been recorded."""

    @abstractmethod
    def set_key(self, key, value):
        """Record *key* together with its associated *value*."""


class FileDB(BaseDB):
    """Key store that keeps one ``<key>.json`` file per key in *workdir*."""

    def __init__(self, workdir):
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Write a scratch file into the work dir and read it back.

        Raises:
            OSError: if the bytes read back differ from the bytes written
                (errors creating/opening the scratch file propagate as-is).
        """
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                # Explicit check instead of ``assert``: assertions are removed
                # under ``python -O`` and the check would silently disappear.
                if tf.read() != TEST_STRING:
                    raise OSError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the JSON file that stores *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True if a file for *key* exists in the work dir."""
        filename = self._get_key_filename(key)
        return self._ospath.isfile(filename)

    def set_key(self, key, obj):
        """Store *obj* as indented JSON in the file for *key*."""
        filename = self._get_key_filename(key)
        with open(filename, 'w') as f:
            json.dump(obj, f, indent=4)
            f.flush()


class Hasher:
    """Derive a stable key for a driver record from selected fields.

    *key_components* is a sequence of index paths; each path is applied to
    the record by successive ``[]`` lookups.
    """

    def __init__(self, key_components):
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Follow *component_path* into *obj* and return the result as UTF-8.

        A path that hits a missing key or index (KeyError/IndexError)
        contributes empty bytes instead of failing.
        """
        res = obj
        try:
            for path_component in component_path:
                res = res[path_component]
        except (KeyError, IndexError):
            return b''
        return str(res).encode('utf-8')

    def hash_object(self, obj):
        """Return the SHA-256 hex digest of *obj*'s NUL-joined components."""
        return HASH(HASH_DELIM.join(
            self._eval_key_component(obj, c) for c in self._key_components)
        ).hexdigest()


class BaseNotifier(ABC):
    """Interface of a sink that is told about newly seen driver records."""

    @abstractmethod
    def notify(self, obj):
        """Deliver *obj*; raising an exception signals a failed delivery."""


class EmailNotifier(BaseNotifier):
    """Send each new driver record as an email via the ``mailer`` package."""

    def __init__(self, name, *, from_addr, to_addrs, host='localhost',
                 port=None, local_hostname=None, use_ssl=False,
                 use_starttls=False, login=None, password=None, timeout=10):
        self.name = name
        self._from_addr = from_addr
        # Imported lazily so ``mailer`` is only required when email
        # notifications are actually configured.
        self._Mailer = importlib.import_module('mailer').Mailer
        self._MIMEText = importlib.import_module('email.mime.text').MIMEText
        self._MIMEMult = importlib.import_module(
            'email.mime.multipart').MIMEMultipart
        self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
        self._encoders = importlib.import_module('email.encoders')
        self._mimeheader = importlib.import_module('email.header').Header
        self._m = self._Mailer(from_addr=from_addr,
                               host=host,
                               port=port,
                               local_hostname=local_hostname,
                               use_ssl=use_ssl,
                               use_starttls=use_starttls,
                               login=login,
                               password=password,
                               timeout=timeout)
        self._to_addrs = to_addrs

    def notify(self, obj):
        """Email *obj* as a readable body plus a base64 ``obj.json`` attachment."""
        msg = self._MIMEMult()
        msg['Subject'] = self._mimeheader("New Nvidia driver available!",
                                          "utf-8")
        msg['From'] = self._from_addr
        msg['To'] = ', '.join(self._to_addrs)
        # Attachment: ASCII-only JSON (non-ASCII escaped as \uXXXX), so the
        # ``encode('ascii')`` below cannot fail.
        obj_text = json.dumps(obj, indent=4, ensure_ascii=True)
        # Body: human-readable JSON; the text part is UTF-8 so it can carry
        # non-ASCII characters verbatim.
        msg_text = json.dumps(obj, indent=4, ensure_ascii=False)
        body = "See attached JSON or message body below:\n"
        body += msg_text
        msg.attach(self._MIMEText(body, 'plain', 'utf-8'))
        p = self._MIMEBase('application', 'octet-stream')
        p.set_payload(obj_text.encode('ascii'))
        self._encoders.encode_base64(p)
        p.add_header('Content-Disposition', "attachment; filename=obj.json")
        msg.attach(p)
        self._m.send(self._to_addrs, msg.as_string())


class CommandNotifier(BaseNotifier):
    """Run *cmdline* and feed the driver record to it as JSON on stdin."""

    def __init__(self, name, *, cmdline, timeout=10):
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Pipe *obj* (indented UTF-8 JSON) into a fresh process.

        The process is killed if it does not finish within the timeout.
        NOTE(review): the exit status is not checked and a timeout is not
        re-raised, so both count as a successful notification — confirm
        this best-effort behaviour is intended.
        """
        proc = self._subprocess.Popen(self._cmdline,
                                      stdin=self._subprocess.PIPE)
        try:
            proc.communicate(json.dumps(obj, indent=4).encode('utf-8'),
                             self._timeout)
        except self._subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.communicate()


class BaseChannel(ABC):
    """Interface of a source of driver records."""

    @abstractmethod
    def get_latest_drivers(self):
        """Return an iterable of driver-record dicts."""


class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``."""

    def __init__(self, name, notebook=False, x86_64=True, os_version="10.0",
                 os_build="17763", language=1033, beta=False, dch=False,
                 crd=False, timeout=10):
        self.name = name
        self._notebook = notebook
        self._x86_64 = x86_64
        self._os_version = os_version
        self._os_build = os_build
        self._language = language
        self._beta = beta
        self._dch = dch
        self._crd = crd
        self._timeout = timeout
        gfe_get_driver = importlib.import_module('gfe_get_driver')
        self._get_latest_drivers = gfe_get_driver.get_latest_geforce_driver

    def get_latest_drivers(self):
        """Yield at most one record: the upstream result plus channel info.

        Yields nothing when the upstream lookup returns None.
        """
        res = self._get_latest_drivers(notebook=self._notebook,
                                       x86_64=self._x86_64,
                                       os_version=self._os_version,
                                       os_build=self._os_build,
                                       language=self._language,
                                       beta=self._beta,
                                       dch=self._dch,
                                       crd=self._crd,
                                       timeout=self._timeout)
        if res is None:
            return
        # The upstream result is treated as a dict and extended in place.
        res.update({
            'ChannelAttributes': {
                'Name': self.name,
                'Type': self.__class__.__name__,
                # %d truncates the float, e.g. "10.0" -> "Windows10_64".
                'OS': 'Windows%d_%d' % (float(self._os_version),
                                        64 if self._x86_64 else 32),
                'OSBuild': self._os_build,
                'Language': self._language,
                'Beta': self._beta,
                'DCH': self._dch,
                'CRD': self._crd,
                'Mobile': self._notebook,
            }
        })
        yield res


class NvidiaDownloadsChannel(BaseChannel):
    """Channel backed by the ``get_nvidia_downloads`` module.

    The string parameters are looked up by name in that module's ``OS``,
    ``Product``, ``CertLevel``, ``DriverType``, ``DriverLanguage`` and
    ``CUDAToolkitVersion`` mappings; an unknown name raises at construction.
    """

    def __init__(self, name, *,
                 os="Linux_64",
                 product="GeForce",
                 certlevel="All",
                 driver_type="Standard",
                 lang="English",
                 cuda_ver="Nothing",
                 timeout=10):
        self.name = name
        gnd = importlib.import_module('get_nvidia_downloads')
        self._gnd = gnd
        self._os = gnd.OS[os]
        self._product = gnd.Product[product]
        self._certlevel = gnd.CertLevel[certlevel]
        self._driver_type = gnd.DriverType[driver_type]
        self._lang = gnd.DriverLanguage[lang]
        self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield at most one record built from ``gnd.get_drivers``' result."""
        latest = self._gnd.get_drivers(os=self._os,
                                       product=self._product,
                                       certlevel=self._certlevel,
                                       driver_type=self._driver_type,
                                       lang=self._lang,
                                       cuda_ver=self._cuda_ver,
                                       timeout=self._timeout)
        if not latest:
            return
        res = {
            'DriverAttributes': {
                'Version': latest['version'],
                'Name': latest['name'],
                'NameLocalized': latest['name'],
            },
            'ChannelAttributes': {
                'Name': self.name,
                'Type': self.__class__.__name__,
                'OS': self._os.name,
                'Product': self._product.name,
                'CertLevel': self._certlevel.name,
                'DriverType': self._driver_type.name,
                'Lang': self._lang.name,
                'CudaVer': self._cuda_ver.name,
            }
        }
        if 'download_url' in latest:
            res['DriverAttributes']['DownloadURL'] = latest['download_url']
        yield res


class CudaToolkitDownloadsChannel(BaseChannel):
    """Channel backed by ``get_cuda_downloads.get_latest_cuda_tk``."""

    def __init__(self, name, *, timeout=10):
        self.name = name
        gcd = importlib.import_module('get_cuda_downloads')
        self._gcd = gcd
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield at most one record naming the latest CUDA toolkit.

        No version is available from this source, so ``'???'`` is used.
        """
        latest = self._gcd.get_latest_cuda_tk(timeout=self._timeout)
        if not latest:
            return
        yield {
            'DriverAttributes': {
                'Version': '???',
                'Name': latest,
                'NameLocalized': latest,
            },
            'ChannelAttributes': {
                'Name': self.name,
                'Type': self.__class__.__name__,
            }
        }


class VulkanBetaDownloadsChannel(BaseChannel):
    """Channel backed by ``get_vulkan_downloads.get_drivers``."""

    def __init__(self, name, *, timeout=10):
        self.name = name
        self._timeout = timeout
        self._gvd = importlib.import_module('get_vulkan_downloads')

    def get_latest_drivers(self):
        """Yield one record per entry (``version``/``name``/``os`` keys)."""
        drivers = self._gvd.get_drivers(timeout=self._timeout)
        if drivers is None:
            return
        for drv in drivers:
            yield {
                'DriverAttributes': {
                    'Version': drv['version'],
                    'Name': drv['name'],
                    'NameLocalized': drv['name'],
                },
                'ChannelAttributes': {
                    'Name': self.name,
                    'Type': self.__class__.__name__,
                    'OS': drv['os'],
                }
            }


class DebRepoChannel(BaseChannel):
    """Channel listing package versions from a Debian-style repo.

    Backed by ``get_deb_drivers.get_deb_versions``; every record uses the
    configured *driver_name* as its display name.
    """

    def __init__(self, name, *, url, pkg_pattern,
                 driver_name="Linux x64 (AMD64/EM64T) Display Driver",
                 timeout=10):
        self.name = name
        self._gdd = importlib.import_module('get_deb_drivers')
        self._url = url
        self._pkg_pattern = pkg_pattern
        self._driver_name = driver_name
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield one record per package (``.version``/``.name`` attributes)."""
        drivers = self._gdd.get_deb_versions(url=self._url,
                                             name=self._pkg_pattern,
                                             timeout=self._timeout)
        if drivers is None:
            return
        for drv in drivers:
            yield {
                'DriverAttributes': {
                    'Version': drv.version,
                    'DebPkgName': drv.name,
                    'Name': self._driver_name,
                    'NameLocalized': self._driver_name,
                },
                'ChannelAttributes': {
                    'Name': self.name,
                    'Type': self.__class__.__name__,
                    'OS': 'Linux_64',
                    'PkgPattern': self._pkg_pattern,
                }
            }


def parse_args():
    """Parse command-line arguments (only ``-c/--config``)."""
    parser = argparse.ArgumentParser(
        description="Watches for GeForce experience driver updates for "
        "configured systems",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--config",
                        default="/etc/nv-driver-locator.json",
                        help="config file location")
    args = parser.parse_args()
    return args


class DriverLocator:
    """Wire channels, key DB, hasher and notifiers together from a config."""

    # Process exit status; _perror() switches it to 3 on this instance.
    _ret_code = 0

    def __init__(self, conf):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Build channels from config; failures are logged and skipped."""
        channel_types = {
            'gfe_client': GFEClientChannel,
            'nvidia_downloads': NvidiaDownloadsChannel,
            'cuda_downloads': CudaToolkitDownloadsChannel,
            'vulkan_beta': VulkanBetaDownloadsChannel,
            'deb_packages': DebRepoChannel,
        }

        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Build the key DB from config; errors propagate to the caller."""
        db_types = {
            'file': FileDB,
        }

        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Build notifiers from config; failures are logged and skipped."""
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }

        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* and mark the run as failed (exit code 3)."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Send *obj* to every notifier.

        Returns:
            True if at least one notifier succeeded.  With no notifiers
            configured this is False, so nothing is ever recorded.
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s."
                             % (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Poll every channel once and notify about drivers not yet in the DB.

        Errors in one channel are logged and the loop moves on to the next.

        Returns:
            0 if no error was logged during this instance's lifetime,
            otherwise 3.
        """
        for ch in self._channels:
            counter = 0
            try:
                drivers = ch.get_latest_drivers()
            except Exception as e:
                self._perror("get_latest_drivers() invocation failed for "
                             "channel %s. Exception: %s. Continuing..." %
                             (repr(ch.name), str(e)))
                continue

            # The channels above are generators, so fetch errors usually
            # surface here, while iterating, rather than in the call above.
            try:
                for drv in drivers:
                    counter += 1
                    try:
                        key = self._hasher.hash_object(drv)
                    except Exception as e:
                        self._perror("Key evaluation failed for channel %s. "
                                     "Exception: %s" % (repr(ch.name), str(e)))
                        continue
                    # Record the key only after a successful notification so
                    # a driver is retried on the next run if all notifiers
                    # failed.
                    if not self._db.check_key(key):
                        if self._notify_all(drv):
                            self._db.set_key(key, drv)
            except Exception as e:
                self._perror("channel %s enumeration terminated with exception: %s"
                             % (repr(ch.name), str(e)))
                continue

            if not counter:
                self._perror("Drivers not found for channel %s"
                             % (repr(ch.name),))
        return self._ret_code


def setup_logger(name, verbosity):
    """Attach a timestamped stderr handler at *verbosity* to logger *name*."""
    logger = logging.getLogger(name)
    logger.setLevel(verbosity)
    handler = logging.StreamHandler()
    handler.setLevel(verbosity)
    handler.setFormatter(logging.Formatter('%(asctime)s '
                                           '%(levelname)-8s '
                                           '%(name)s: %(message)s',
                                           '%Y-%m-%d %H:%M:%S'))
    logger.addHandler(handler)
    return logger


def main():
    """CLI entry point: load the JSON config, run once, exit with status."""
    args = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)
    with open(args.config, 'r') as conf_file:
        conf = json.load(conf_file)
    ret = DriverLocator(conf).run()
    sys.exit(ret)


if __name__ == '__main__':
    main()