You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

210 lines
8.4 KiB
Python

#!/usr/bin/env python3
import sys
import argparse
import json
import os.path
import posixpath
from string import Template
from functools import partial
import urllib.request
from constants import OSKind, Product, WinSeries, DATAFILE_PATH, \
DRIVER_URL_TEMPLATE, DRIVER_DIR_PREFIX, BASE_PATH
from utils import find_driver, linux_driver_key, windows_driver_key
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: argument list to parse (program name excluded). ``None``
        means argparse's default, ``sys.argv[1:]``.
    :returns: the populated :class:`argparse.Namespace`.
    """
    def check_enum_arg(enum, value):
        # Map a member *name* typed on the command line to the enum member.
        # argparse does not translate KeyError, so convert it to the error
        # type argparse reports as a clean usage message.
        try:
            return enum[value]
        except KeyError:
            raise argparse.ArgumentTypeError(
                "%s is not a valid option for %s"
                % (repr(value), repr(enum.__name__))) from None

    parser = argparse.ArgumentParser(
        description="Adds new Nvidia driver into drivers.json file "
                    "in your repo working copy",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    os_options = parser.add_argument_group("OS options")
    os_group = os_options.add_mutually_exclusive_group(required=True)
    os_group.add_argument("-L", "--linux",
                          action="store_const",
                          dest="os",
                          const=OSKind.Linux,
                          help="add Linux driver")
    os_group.add_argument("-W", "--win",
                          action="store_const",
                          dest="os",
                          const=OSKind.Windows,
                          help="add Windows driver")

    win_opts = parser.add_argument_group("Windows-specific options")
    win_opts.add_argument("--variant",
                          default="",
                          help="driver variant (use for special cases like "
                          "\"Studio Driver\")")
    win_opts.add_argument("-P", "--product",
                          type=partial(check_enum_arg, Product),
                          choices=list(Product),
                          default=Product.GeForce,
                          help="product type")
    win_opts.add_argument("-w", "--winseries",
                          type=partial(check_enum_arg, WinSeries),
                          choices=list(WinSeries),
                          default=WinSeries.win10,
                          help="Windows series")
    win_opts.add_argument("--patch32",
                          default="${winseries}_x64/"
                          "${drvprefix}${version}/nvencodeapi.1337",
                          help="template for Windows 32bit patch URL")
    win_opts.add_argument("--patch64",
                          default="${winseries}_x64/"
                          "${drvprefix}${version}/nvencodeapi64.1337",
                          help="template for Windows 64bit patch URL")
    win_opts.add_argument("--skip-patch-check",
                          action="store_true",
                          help="skip patch files presence test")

    parser.add_argument("-U", "--url",
                        help="override driver link")
    parser.add_argument("--skip-url-check",
                        action="store_true",
                        help="skip driver URL check")
    parser.add_argument("--no-fbc",
                        dest="fbc",
                        action="store_false",
                        help="add driver w/o NvFBC patch")
    parser.add_argument("--no-enc",
                        dest="enc",
                        action="store_false",
                        help="add driver w/o NVENC patch")
    parser.add_argument("version",
                        help="driver version")
    return parser.parse_args(argv)
def posixpath_components(path):
    """Split a POSIX path into its components, outermost first.

    The root ("/") is not reported as a component, and a trailing
    separator does not produce an empty final component, e.g.
    "a/b/" -> ["a", "b"] and "/usr/lib" -> ["usr", "lib"].
    """
    parts = []
    head, tail = posixpath.split(path)
    # posixpath.split() reaches a fixed point (head == path) at "" or the root.
    while head != path:
        parts.insert(0, tail)
        path = head
        head, tail = posixpath.split(path)
    if parts and parts[-1] == "":
        parts.pop()
    return parts
def validate_url(url, min_length=50 * 2**20, timeout=10):
    """Check that ``url`` reports a large enough download via a HEAD request.

    :param url: driver download URL.
    :param min_length: smallest acceptable ``Content-Length`` in bytes
        (default 50 MiB).
    :param timeout: timeout in seconds passed to ``urlopen``.
    :raises Exception: if the ``Content-Length`` header is missing, is not an
        integer, or is below ``min_length``. Errors raised by ``urlopen``
        itself (network/HTTP failures) propagate unchanged.
    """
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        length = resp.headers.get('Content-Length')
    # Without these checks a missing or malformed header surfaced as an
    # opaque TypeError/ValueError from int().
    if length is None:
        raise Exception("Driver URL response has no Content-Length header")
    try:
        size = int(length)
    except ValueError:
        raise Exception("Bad driver length: %s" % length) from None
    if size < min_length:
        raise Exception("Bad driver length: %s" % length)
def validate_patch(patch64, patch32, wc_base=None):
    """Check that both Windows patch files exist and are non-empty.

    :param patch64: path of the 64-bit patch, relative to ``wc_base``.
    :param patch32: path of the 32-bit patch, relative to ``wc_base``.
    :param wc_base: directory the patch paths are resolved against. Defaults
        to the absolute ``../../win`` directory relative to ``BASE_PATH``.
    :raises Exception: naming the first missing file (64-bit checked first);
        if both exist, naming the first empty one.
    """
    if wc_base is None:
        wc_base = os.path.abspath(os.path.join(BASE_PATH, "..", "..", "win"))
    filepaths = [os.path.join(wc_base, patch) for patch in (patch64, patch32)]
    # Existence of both files is reported before emptiness of either.
    for filepath in filepaths:
        if not os.path.exists(filepath):
            raise Exception("File %s not found!" % filepath)
    # Both files must be size-checked; comparing os.path.exists() to 0 here
    # (as the 32-bit check once did) can never fail after the loop above.
    for filepath in filepaths:
        if os.path.getsize(filepath) == 0:
            raise Exception("File %s empty!" % filepath)
def validate_unique(drivers, new_driver, kf):
    """Raise if ``drivers`` already holds an entry with ``new_driver``'s key.

    :param drivers: existing driver records, passed to ``find_driver``.
    :param new_driver: the candidate record.
    :param kf: key function applied to records to compare them.
    :raises Exception: when ``find_driver`` returns a match.
    """
    existing = find_driver(drivers, kf(new_driver), kf)
    if existing is None:
        return
    raise Exception("Duplicate driver!")
def _candidate_urls(args):
    """Return driver URL candidates: the -U override, else expanded templates."""
    if args.url:
        return [args.url]
    if args.os is OSKind.Linux:
        url_tmpl = DRIVER_URL_TEMPLATE[(args.os, None, None, None)]
    else:
        url_tmpl = DRIVER_URL_TEMPLATE[(args.os,
                                        args.product,
                                        args.winseries,
                                        args.variant)]
    # A template entry may be a single string or a list of alternatives;
    # falsy entries are skipped.
    if isinstance(url_tmpl, str):
        url_tmpl = [url_tmpl]
    return [Template(i).substitute(version=args.version) for i in url_tmpl if i]


def _first_valid_url(urls):
    """Return the first URL that passes validate_url(), or None after reporting."""
    last_exc = None
    for url in urls:
        try:
            validate_url(url)
        except Exception as exc:  # KeyboardInterrupt is not an Exception
            last_exc = exc
        else:
            return url
    print("Driver URL validation failed with error: %s" % str(last_exc), file=sys.stderr)
    print("Driver URL: %s" % ", ".join(urls), file=sys.stderr)
    print("Please use option -U to override driver link manually", file=sys.stderr)
    print("or use option --skip-url-check to submit incorrect URL.", file=sys.stderr)
    return None


def _patch_urls(args):
    """Expand the --patch64/--patch32 templates for a Windows driver."""
    substitutions = dict(winseries=args.winseries,
                         drvprefix=DRIVER_DIR_PREFIX[(args.product, args.variant)],
                         version=args.version)
    return (Template(args.patch64).substitute(substitutions),
            Template(args.patch32).substitute(substitutions))


def _new_driver_entry(args, url, patch64_url, patch32_url):
    """Build the JSON record for the new driver and its OS-specific key function."""
    if args.os is OSKind.Windows:
        new_driver = {
            "os": str(args.winseries),
            "product": str(args.product),
            "version": args.version,
            "variant": args.variant,
            "patch64_url": patch64_url,
            "patch32_url": patch32_url,
            "driver_url": url,
        }
        return new_driver, windows_driver_key
    new_driver = {
        "version": args.version,
        "nvenc_patch": args.enc,
        "nvfbc_patch": args.fbc,
    }
    if url:
        new_driver["driver_url"] = url
    return new_driver, linux_driver_key


def main():
    """Validate a new driver entry and append it to the drivers data file.

    :returns: 0 on success, 1 when URL, patch or uniqueness validation fails
        (diagnostics are printed to stderr and the data file is not written).
    """
    args = parse_args()
    urls = _candidate_urls(args)

    if not urls:
        url = ""
    elif args.skip_url_check:
        # Record the link unchecked rather than dropping it, so an explicit
        # -U override (or the first template URL) still reaches the entry.
        url = urls[0]
    else:
        url = _first_valid_url(urls)
        if url is None:
            return 1

    patch64_url = patch32_url = None
    if args.os is OSKind.Windows:
        patch64_url, patch32_url = _patch_urls(args)
        if not args.skip_patch_check:
            try:
                validate_patch(patch64_url, patch32_url)
            except Exception as exc:
                print("Driver patch validation failed with error: %s" % str(exc), file=sys.stderr)
                print("Use options --patch64 and --patch32 to override patch path ", file=sys.stderr)
                print("template or use option --skip-patch-check to submit driver with ", file=sys.stderr)
                print("missing patch files.", file=sys.stderr)
                return 1

    with open(DATAFILE_PATH, encoding="utf-8") as data_file:
        data = json.load(data_file)
    drivers = data[args.os.value]['x86_64']['drivers']
    new_driver, key_fun = _new_driver_entry(args, url, patch64_url, patch32_url)

    # The uniqueness check runs against a sorted copy; the new entry is
    # appended to the original list so existing file order is preserved.
    try:
        validate_unique(sorted(drivers, key=key_fun), new_driver, key_fun)
    except Exception as exc:
        print("Driver uniqueness validation failed with error: %s" % str(exc), file=sys.stderr)
        return 1

    drivers.append(new_driver)
    with open(DATAFILE_PATH, 'w', encoding="utf-8") as data_file:
        json.dump(data, data_file, indent=4)
        data_file.write('\n')
    return 0
if __name__ == '__main__':
    # Propagate main()'s status so callers/CI see validation failures;
    # sys.exit(None) still exits with status 0.
    sys.exit(main())