diff --git a/tools/nv-driver-locator/nv-driver-locator.py b/tools/nv-driver-locator/nv-driver-locator.py index 785ca97..d94cfad 100755 --- a/tools/nv-driver-locator/nv-driver-locator.py +++ b/tools/nv-driver-locator/nv-driver-locator.py @@ -150,7 +150,7 @@ class CommandNotifier(BaseNotifier): class BaseChannel(ABC): @abstractmethod - def get_latest_driver(self): + def get_latest_drivers(self): pass @@ -175,10 +175,10 @@ class GFEClientChannel(BaseChannel): self._crd = crd self._timeout = timeout gfe_get_driver = importlib.import_module('gfe_get_driver') - self._get_latest_driver = gfe_get_driver.get_latest_geforce_driver + self._get_latest_drivers = gfe_get_driver.get_latest_geforce_driver - def get_latest_driver(self): - res = self._get_latest_driver(notebook=self._notebook, + def get_latest_drivers(self): + res = self._get_latest_drivers(notebook=self._notebook, x86_64=self._x86_64, os_version=self._os_version, os_build=self._os_build, @@ -187,6 +187,8 @@ class GFEClientChannel(BaseChannel): dch=self._dch, crd=self._crd, timeout=self._timeout) + if res is None: + return res.update({ 'ChannelAttributes': { 'Name': self.name, @@ -201,7 +203,7 @@ class GFEClientChannel(BaseChannel): 'Mobile': self._notebook, } }) - return res + yield res class NvidiaDownloadsChannel(BaseChannel): @@ -224,7 +226,7 @@ class NvidiaDownloadsChannel(BaseChannel): self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver] self._timeout = timeout - def get_latest_driver(self): + def get_latest_drivers(self): latest = self._gnd.get_drivers(os=self._os, product=self._product, certlevel=self._certlevel, @@ -233,7 +235,7 @@ class NvidiaDownloadsChannel(BaseChannel): cuda_ver=self._cuda_ver, timeout=self._timeout) if not latest: - return None + return res = { 'DriverAttributes': { 'Version': latest['version'], @@ -253,7 +255,7 @@ class NvidiaDownloadsChannel(BaseChannel): } if 'download_url' in latest: res['DriverAttributes']['DownloadURL'] = latest['download_url'] - return res + yield res class CudaToolkitDownloadsChannel(BaseChannel): @@ -264,11 +266,11 @@ class CudaToolkitDownloadsChannel(BaseChannel): self._gcd = gcd self._timeout = timeout - def get_latest_driver(self): + def get_latest_drivers(self): latest = self._gcd.get_latest_cuda_tk(timeout=self._timeout) if not latest: - return None - return { + return + yield { 'DriverAttributes': { 'Version': '???', 'Name': latest, @@ -289,11 +291,11 @@ class VulkanBetaDownloadsChannel(BaseChannel): self._os = os self._timeout = timeout - def get_latest_driver(self): + def get_latest_drivers(self): drivers = vulkan_downloads(timeout=self._timeout) for drv in drivers: if drv["os"] == self._os: - return { + yield { 'DriverAttributes': { 'Version': drv['version'], 'Name': drv['name'], @@ -301,7 +303,7 @@ class VulkanBetaDownloadsChannel(BaseChannel): } } else: - return None + return def parse_args(): @@ -389,26 +391,39 @@ class DriverLocator: def run(self): for ch in self._channels: + counter = 0 try: - drv = ch.get_latest_driver() + drivers = ch.get_latest_drivers() except Exception as e: - self._perror("get_latest_driver() invocation failed for " + self._perror("get_latest_drivers() invocation failed for " "channel %s. Exception: %s. Continuing..." % (repr(ch.name), str(e))) continue - if drv is None: - self._perror("Driver not found for channel %s" % - (repr(ch.name),)) - continue + try: - key = self._hasher.hash_object(drv) + # Fetch + for drv in drivers: + counter += 1 + # Hash + try: + key = self._hasher.hash_object(drv) + except Exception as e: + self._perror("Key evaluation failed for channel %s. " + "Exception: %s" % (repr(name), str(e))) + continue + + # Notify + if not self._db.check_key(key): + if self._notify_all(drv): + self._db.set_key(key, drv) except Exception as e: - self._perror("Key evaluation failed for channel %s. " - "Exception: %s" % (repr(name), str(e))) + self._perror("channel %s enumeration terminated with exception: %s" % + (repr(name), str(e))) continue - if not self._db.check_key(key): - if self._notify_all(drv): - self._db.set_key(key, drv) + + if not counter: + self._perror("Drivers not found for channel %s" % + (repr(ch.name),)) return self._ret_code