You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

493 lines
16 KiB
Python

#!/usr/bin/env python3
import sys
import json
import argparse
import hashlib
import importlib
import logging
from abc import ABC, abstractmethod
import functools
# NUL byte inserted between the serialized key components before hashing
# (see Hasher.hash_object), so neighbouring components stay separated.
HASH_DELIM = b"\0"
# Hash constructor used to turn the joined key components into a DB key.
HASH = hashlib.sha256
class BaseDB(ABC):
    """Interface of the key store that remembers which drivers were reported."""

    @abstractmethod
    def check_key(self, key):
        """Subclasses return whether *key* is already stored."""

    @abstractmethod
    def set_key(self, key, value):
        """Subclasses persist *value* under *key*."""
class FileDB(BaseDB):
    """Directory-backed key store: each key is persisted as ``<key>.json``.

    Args:
        workdir: Directory that holds the key files. It is verified to be
            writable at construction time.

    Raises:
        OSError: If a temporary file cannot be created in *workdir*, or if
            the test data written there does not read back unchanged.
    """

    def __init__(self, workdir):
        # Modules are loaded lazily, matching the rest of this file.
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Write and read back a throwaway file to prove *workdir* works."""
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                readback = tf.read()
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        if readback != TEST_STRING:
            raise OSError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the file that stores *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True if a file for *key* already exists."""
        return self._ospath.isfile(self._get_key_filename(key))

    def set_key(self, key, obj):
        """Serialize *obj* as indented JSON into the file for *key*."""
        filename = self._get_key_filename(key)
        # Explicit encoding keeps the file independent of the process locale.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=4)
class Hasher:
    """Derive a hex-digest key from selected fields of a driver record.

    Args:
        key_components: Iterable of paths. Each path is a sequence of
            keys/indices walked from the root of the record, for example
            ``["DriverAttributes", "Version"]``.
    """

    def __init__(self, key_components):
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Return the UTF-8 encoded ``str()`` of the value at *component_path*.

        A path that cannot be followed yields ``b''``. This covers a missing
        key or index, and also a step into a value that cannot be indexed
        that way (e.g. ``None``, or a string indexed by a string key).
        Different channels emit differently shaped records, so one list of
        key components has to tolerate all of them.

        Raises:
            TypeError: If *component_path* itself is not iterable (a
                configuration error, deliberately not masked).
        """
        res = obj
        for path_component in component_path:
            try:
                res = res[path_component]
            except (KeyError, IndexError, TypeError):
                return b''
        return str(res).encode('utf-8')

    def hash_object(self, obj):
        """Return the hex digest of all key components joined by HASH_DELIM."""
        return HASH(HASH_DELIM.join(
            self._eval_key_component(obj, c) for c in self._key_components)
        ).hexdigest()
class BaseNotifier(ABC):
    """Interface for objects that announce a newly seen driver record."""

    @abstractmethod
    def notify(self, obj):
        """Subclasses deliver *obj* (a driver record) to their destination."""
class EmailNotifier(BaseNotifier):
    """Send each new driver record by e-mail using ``mailer.Mailer``.

    Args:
        name: Notifier name (DriverLocator includes it in error messages).
        from_addr: Sender address; also passed to ``Mailer``.
        to_addrs: Recipient addresses (joined for the ``To`` header and
            passed to ``Mailer.send``).
        host, port, local_hostname, use_ssl, use_starttls, login, password,
        timeout: Forwarded unchanged to ``mailer.Mailer``.
    """

    def __init__(self, name, *,
                 from_addr,
                 to_addrs,
                 host='localhost',
                 port=None,
                 local_hostname=None,
                 use_ssl=False,
                 use_starttls=False,
                 login=None,
                 password=None,
                 timeout=10):
        self.name = name
        self._from_addr = from_addr
        self._Mailer = importlib.import_module('mailer').Mailer
        self._MIMEText = importlib.import_module('email.mime.text').MIMEText
        self._MIMEMult = importlib.import_module(
            'email.mime.multipart').MIMEMultipart
        self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
        self._encoders = importlib.import_module('email.encoders')
        self._mimeheader = importlib.import_module('email.header').Header
        self._m = self._Mailer(from_addr=from_addr,
                               host=host,
                               port=port,
                               local_hostname=local_hostname,
                               use_ssl=use_ssl,
                               use_starttls=use_starttls,
                               login=login,
                               password=password,
                               timeout=timeout)
        self._to_addrs = to_addrs

    def notify(self, obj):
        """E-mail *obj* as inline text plus a base64 ``obj.json`` attachment.

        Args:
            obj: JSON-serializable driver record.
        """
        msg = self._MIMEMult()
        msg['Subject'] = self._mimeheader("New Nvidia driver available!",
                                          "utf-8")
        msg['From'] = self._from_addr
        msg['To'] = ', '.join(self._to_addrs)
        # Attachment text keeps non-ASCII characters verbatim...
        obj_text = json.dumps(obj, indent=4, ensure_ascii=False)
        # ...while the inline body uses \uXXXX escapes and stays pure ASCII.
        msg_text = json.dumps(obj, indent=4, ensure_ascii=True)
        body = "See attached JSON or message body below:\n"
        body += msg_text
        msg.attach(self._MIMEText(body, 'plain', 'utf-8'))
        p = self._MIMEBase('application', 'octet-stream')
        # obj_text may contain non-ASCII characters (e.g. in NameLocalized),
        # so it must be encoded as UTF-8; an ASCII encode would raise
        # UnicodeEncodeError and abort the notification.
        p.set_payload(obj_text.encode('utf-8'))
        self._encoders.encode_base64(p)
        p.add_header('Content-Disposition', "attachment; filename=obj.json")
        msg.attach(p)
        self._m.send(self._to_addrs, msg.as_string())
class CommandNotifier(BaseNotifier):
    """Pipe each new driver record, as indented JSON, into an external command.

    Args:
        name: Notifier name (DriverLocator includes it in error messages).
        cmdline: Program and arguments given to ``subprocess.Popen``; no
            shell is involved.
        timeout: Seconds to wait for the command before it is killed.
    """

    def __init__(self, name, *,
                 cmdline,
                 timeout=10):
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Run the command with *obj* (UTF-8 JSON) on its stdin.

        Failures are raised instead of swallowed, so a hung or failing command
        is reported as a failed notification and is not treated as a
        successful delivery.

        Raises:
            subprocess.TimeoutExpired: The command did not finish within
                ``timeout`` seconds. It is killed and reaped first.
            subprocess.CalledProcessError: The command exited non-zero.
        """
        sp = self._subprocess
        payload = json.dumps(obj, indent=4).encode('utf-8')
        # The context manager closes the pipes and waits on every path.
        with sp.Popen(self._cmdline, stdin=sp.PIPE) as proc:
            try:
                proc.communicate(payload, self._timeout)
            except sp.TimeoutExpired:
                proc.kill()
                proc.communicate()
                raise
        if proc.returncode != 0:
            raise sp.CalledProcessError(proc.returncode, self._cmdline)
class BaseChannel(ABC):
    """Interface for a source that reports the latest available drivers."""

    @abstractmethod
    def get_latest_drivers(self):
        """Subclasses provide an iterable of driver-record dicts."""
class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``.

    Every constructor option except *name* is forwarded unchanged to that
    function. The same options are echoed back in ``ChannelAttributes``.
    """

    def __init__(self, name, notebook=False,
                 x86_64=True,
                 os_version="10.0",
                 os_build="17763",
                 language=1033,
                 beta=False,
                 dch=False,
                 crd=False,
                 timeout=10):
        self.name = name
        # Query parameters, stored verbatim.
        self._notebook = notebook
        self._x86_64 = x86_64
        self._os_version = os_version
        self._os_build = os_build
        self._language = language
        self._beta = beta
        self._dch = dch
        self._crd = crd
        self._timeout = timeout
        self._get_latest_drivers = importlib.import_module(
            'gfe_get_driver').get_latest_geforce_driver

    def get_latest_drivers(self):
        """Yield at most one record: the lookup result plus channel metadata.

        Nothing is yielded when the lookup returns ``None``.
        """
        found = self._get_latest_drivers(
            notebook=self._notebook, x86_64=self._x86_64,
            os_version=self._os_version, os_build=self._os_build,
            language=self._language, beta=self._beta, dch=self._dch,
            crd=self._crd, timeout=self._timeout)
        if found is not None:
            bitness = 64 if self._x86_64 else 32
            # '%d' truncates the float, so "10.0" renders as "Windows10_64".
            os_label = 'Windows%d_%d' % (float(self._os_version), bitness)
            found.update({
                'ChannelAttributes': {
                    'Name': self.name,
                    'Type': type(self).__name__,
                    'OS': os_label,
                    'OSBuild': self._os_build,
                    'Language': self._language,
                    'Beta': self._beta,
                    'DCH': self._dch,
                    'CRD': self._crd,
                    'Mobile': self._notebook,
                }
            })
            yield found
class NvidiaDownloadsChannel(BaseChannel):
    """Channel backed by ``get_nvidia_downloads.get_drivers``.

    The string options are member names looked up by subscript in that
    module's ``OS``, ``Product``, ``CertLevel``, ``DriverType``,
    ``DriverLanguage`` and ``CUDAToolkitVersion`` objects. The lookup happens
    here, so an unknown name makes construction fail.
    """

    def __init__(self, name, *,
                 os="Linux_64",
                 product="GeForce",
                 certlevel="All",
                 driver_type="Standard",
                 lang="English",
                 cuda_ver="Nothing",
                 timeout=10):
        self.name = name
        gnd = self._gnd = importlib.import_module('get_nvidia_downloads')
        # Resolve option names to the module's members up front.
        self._os = gnd.OS[os]
        self._product = gnd.Product[product]
        self._certlevel = gnd.CertLevel[certlevel]
        self._driver_type = gnd.DriverType[driver_type]
        self._lang = gnd.DriverLanguage[lang]
        self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield one record for the newest matching driver, or nothing."""
        newest = self._gnd.get_drivers(os=self._os,
                                       product=self._product,
                                       certlevel=self._certlevel,
                                       driver_type=self._driver_type,
                                       lang=self._lang,
                                       cuda_ver=self._cuda_ver,
                                       timeout=self._timeout)
        if not newest:
            return
        driver_attrs = {
            'Version': newest['version'],
            'Name': newest['name'],
            'NameLocalized': newest['name'],
        }
        channel_attrs = {
            'Name': self.name,
            'Type': type(self).__name__,
            'OS': self._os.name,
            'Product': self._product.name,
            'CertLevel': self._certlevel.name,
            'DriverType': self._driver_type.name,
            'Lang': self._lang.name,
            'CudaVer': self._cuda_ver.name,
        }
        # The download link is optional in the lookup result.
        if 'download_url' in newest:
            driver_attrs['DownloadURL'] = newest['download_url']
        yield {
            'DriverAttributes': driver_attrs,
            'ChannelAttributes': channel_attrs,
        }
class CudaToolkitDownloadsChannel(BaseChannel):
    """Channel backed by ``get_cuda_downloads.get_latest_cuda_tk``.

    The lookup result is used directly as the driver name. No version is
    extracted, so ``Version`` is the placeholder ``'???'``.
    """

    def __init__(self, name, *,
                 timeout=10):
        self.name = name
        self._gcd = importlib.import_module('get_cuda_downloads')
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield one record for the latest toolkit, or nothing if none found."""
        toolkit = self._gcd.get_latest_cuda_tk(timeout=self._timeout)
        if toolkit:
            yield {
                'DriverAttributes': {
                    'Version': '???',
                    'Name': toolkit,
                    'NameLocalized': toolkit,
                },
                'ChannelAttributes': {
                    'Name': self.name,
                    'Type': type(self).__name__,
                },
            }
class VulkanBetaDownloadsChannel(BaseChannel):
    """Channel backed by ``get_vulkan_downloads.get_drivers``.

    Each entry of the lookup result (read via its ``version``, ``name`` and
    ``os`` keys) becomes one record.
    """

    def __init__(self, name, *,
                 timeout=10):
        self.name = name
        self._timeout = timeout
        self._gvd = importlib.import_module('get_vulkan_downloads')

    def get_latest_drivers(self):
        """Yield one record per driver; nothing when the lookup returns None."""
        listing = self._gvd.get_drivers(timeout=self._timeout)
        if listing is None:
            return
        yield from (
            {
                'DriverAttributes': {
                    'Version': entry['version'],
                    'Name': entry['name'],
                    'NameLocalized': entry['name'],
                },
                'ChannelAttributes': {
                    'Name': self.name,
                    'Type': type(self).__name__,
                    'OS': entry['os'],
                },
            }
            for entry in listing
        )
class DebRepoChannel(BaseChannel):
    """Channel backed by ``get_deb_drivers.get_deb_versions``.

    Args:
        name: Channel name.
        url: Repository location passed to the lookup.
        pkg_pattern: Package-name pattern passed to the lookup as ``name``.
        driver_name: Human-readable name reported for every package found.
        timeout: Forwarded to the lookup.
    """

    def __init__(self, name, *,
                 url,
                 pkg_pattern,
                 driver_name="Linux x64 (AMD64/EM64T) Display Driver",
                 timeout=10):
        self.name = name
        self._gdd = importlib.import_module('get_deb_drivers')
        self._url = url
        self._pkg_pattern = pkg_pattern
        self._driver_name = driver_name
        self._timeout = timeout

    def get_latest_drivers(self):
        """Yield one record per package version; nothing if lookup gives None."""
        packages = self._gdd.get_deb_versions(url=self._url,
                                              name=self._pkg_pattern,
                                              timeout=self._timeout)
        if packages is None:
            return
        for pkg in packages:
            driver_attrs = {
                'Version': pkg.version,
                'DebPkgName': pkg.name,
                'Name': self._driver_name,
                'NameLocalized': self._driver_name,
            }
            channel_attrs = {
                'Name': self.name,
                'Type': type(self).__name__,
                'OS': 'Linux_64',
                'PkgPattern': self._pkg_pattern,
            }
            yield {
                'DriverAttributes': driver_attrs,
                'ChannelAttributes': channel_attrs,
            }
def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) means ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` whose ``config`` attribute is the config path.
    """
    parser = argparse.ArgumentParser(
        description="Watches for GeForce experience driver updates for "
                    "configured systems",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--config",
                        default="/etc/nv-driver-locator.json",
                        help="config file location")
    return parser.parse_args(argv)
class DriverLocator:
    """Poll every configured channel and notify about drivers not seen before.

    The config mapping must provide ``channels``, ``db``, ``key_components``
    and ``notifiers``.
    """

    # Exit status returned by run(); _perror() switches it to 3.
    _ret_code = 0

    def __init__(self, conf):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Instantiate channels; a failing entry is logged and skipped.

        Each entry needs ``type`` (a key of ``channel_types``), ``name`` and
        ``params`` (keyword arguments for the channel constructor).
        """
        channel_types = {
            'gfe_client': GFEClientChannel,
            'nvidia_downloads': NvidiaDownloadsChannel,
            'cuda_downloads': CudaToolkitDownloadsChannel,
            'vulkan_beta': VulkanBetaDownloadsChannel,
            'deb_packages': DebRepoChannel,
        }
        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Build the key store. Errors propagate: without a DB nothing works."""
        db_types = {
            'file': FileDB,
        }
        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Instantiate notifiers; a failing entry is logged and skipped."""
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }
        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* at ERROR level and make run() return exit code 3."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Send *obj* through every notifier.

        Returns:
            True if at least one notifier succeeded. With no notifiers it
            returns False, so the driver is not recorded as announced.
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s." %
                             (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Process all channels once and return the process exit code.

        For each driver record, a key is derived with the hasher. Unknown keys
        are announced and then stored, but only if announcing succeeded, so
        failed notifications are retried on the next run. Errors are logged
        per channel and do not stop the other channels.
        """
        for ch in self._channels:
            counter = 0
            try:
                drivers = ch.get_latest_drivers()
            except Exception as e:
                self._perror("get_latest_drivers() invocation failed for "
                             "channel %s. Exception: %s. Continuing..." %
                             (repr(ch.name), str(e)))
                continue
            try:
                # Fetch
                for drv in drivers:
                    counter += 1
                    # Hash
                    try:
                        key = self._hasher.hash_object(drv)
                    except Exception as e:
                        # ch.name: the undefined `name` used here previously
                        # raised NameError and aborted the whole run.
                        self._perror("Key evaluation failed for channel %s. "
                                     "Exception: %s" % (repr(ch.name), str(e)))
                        continue
                    # Notify
                    if not self._db.check_key(key):
                        if self._notify_all(drv):
                            self._db.set_key(key, drv)
            except Exception as e:
                self._perror("channel %s enumeration terminated with "
                             "exception: %s" % (repr(ch.name), str(e)))
                continue
            if not counter:
                self._perror("Drivers not found for channel %s" %
                             (repr(ch.name),))
        return self._ret_code
def setup_logger(name, verbosity):
    """Attach a timestamped ``StreamHandler`` to the logger called *name*.

    The logger and the new handler both get level *verbosity*. Every call
    adds one more handler.

    Returns:
        The configured ``logging.Logger``.
    """
    target = logging.getLogger(name)
    target.setLevel(verbosity)
    line_format = '%(asctime)s %(levelname)-8s %(name)s: %(message)s'
    stamp_format = '%Y-%m-%d %H:%M:%S'
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(verbosity)
    stream_handler.setFormatter(logging.Formatter(line_format, stamp_format))
    target.addHandler(stream_handler)
    return target
def main():
    """Script entry point: load the JSON config, run once, exit with status.

    Only ERROR-level messages from the ``DriverLocator`` logger are printed.
    """
    args = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)
    # JSON text is UTF-8 (RFC 8259); do not depend on the process locale,
    # which may differ under cron/systemd or on Windows.
    with open(args.config, 'r', encoding='utf-8') as conf_file:
        conf = json.load(conf_file)
    ret = DriverLocator(conf).run()
    sys.exit(ret)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()