diff --git a/.github/workflows/check-style.yaml b/.github/workflows/check-style.yaml
new file mode 100644
index 0000000..29a0f82
--- /dev/null
+++ b/.github/workflows/check-style.yaml
@@ -0,0 +1,26 @@
+name: Check style
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "22.3.0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master
+        with:
+          isortVersion: "5.10.1"
diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
new file mode 100644
index 0000000..e6bb8ea
--- /dev/null
+++ b/.github/workflows/run-tests.yaml
@@ -0,0 +1,89 @@
+name: Tests
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  convert-model:
+    runs-on: ubuntu-latest
+    env:
+      BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-py3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+      - name: Delete previous model, if exists
+        run: |
+          python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
+            name='test-bloomd-350m-$GITHUB_HEAD_REF', organization='bloom-testing')" || true
+      - name: Convert model and push to hub
+        run: |
+          python -m cli.convert_model --model bigscience/bloom-350m --output_path ./converted_model \
+            --output_repo bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
+
+
+  run-tests:
+    runs-on: ubuntu-latest
+    needs: convert-model
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+      fail-fast: false
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Test
+        run: |
+          export MODEL_NAME=bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF
+          python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
+            --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
+          SERVER1_PID=$!
+
+          export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
+
+          python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:24 \
+            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          SERVER2_PID=$!
+
+          sleep 30  # wait for server to download layers
+
+          # test individual blocks
+          export PYTHONPATH=.
+          BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
+          BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
+
+          REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
+
+          REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
+
+          kill -s SIGINT $SERVER1_PID $SERVER2_PID
+          echo "Done!"
diff --git a/cli/convert_model.py b/cli/convert_model.py
index a7ffd63..0864a5f 100644
--- a/cli/convert_model.py
+++ b/cli/convert_model.py
@@ -10,8 +10,9 @@ from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
 from src import BloomModel
+from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from src.client import DistributedBloomConfig
-from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
+
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
diff --git a/cli/speed_test.py b/cli/speed_test.py
index 186b529..d3342c3 100755
--- a/cli/speed_test.py
+++ b/cli/speed_test.py
@@ -31,12 +31,13 @@ import xml.parsers.expat
 
 try:
     import gzip
+
     GZIP_BASE = gzip.GzipFile
 except ImportError:
     gzip = None
     GZIP_BASE = object
 
-__version__ = '2.1.4b1'
+__version__ = "2.1.4b1"
 
 
 class FakeShutdownEvent(object):
@@ -46,7 +47,7 @@ class FakeShutdownEvent(object):
 
     @staticmethod
     def isSet():
-        "Dummy method to always return false"""
+        "Dummy method to always return false" ""
         return False
 
     is_set = isSet
@@ -71,6 +72,7 @@ except ImportError:
 
 try:
     import xml.etree.ElementTree as ET
+
     try:
         from xml.etree.ElementTree import _Element as ET_Element
     except ImportError:
@@ -78,23 +80,40 @@ try:
 except ImportError:
     from xml.dom import minidom as DOM
     from xml.parsers.expat import ExpatError
+
     ET = None
 
 try:
-    from urllib2 import (urlopen, Request, HTTPError, URLError,
-                         AbstractHTTPHandler, ProxyHandler,
-                         HTTPDefaultErrorHandler, HTTPRedirectHandler,
-                         HTTPErrorProcessor, OpenerDirector)
+    from urllib2 import (
+        AbstractHTTPHandler,
+        HTTPDefaultErrorHandler,
+        HTTPError,
+        HTTPErrorProcessor,
+        HTTPRedirectHandler,
+        OpenerDirector,
+        ProxyHandler,
+        Request,
+        URLError,
+        urlopen,
+    )
 except ImportError:
-    from urllib.request import (urlopen, Request, HTTPError, URLError,
-                                AbstractHTTPHandler, ProxyHandler,
-                                HTTPDefaultErrorHandler, HTTPRedirectHandler,
-                                HTTPErrorProcessor, OpenerDirector)
+    from urllib.request import (
+        AbstractHTTPHandler,
+        HTTPDefaultErrorHandler,
+        HTTPError,
+        HTTPErrorProcessor,
+        HTTPRedirectHandler,
+        OpenerDirector,
+        ProxyHandler,
+        Request,
+        URLError,
+        urlopen,
+    )
 
 try:
-    from httplib import HTTPConnection, BadStatusLine
+    from httplib import BadStatusLine, HTTPConnection
 except ImportError:
-    from http.client import HTTPConnection, BadStatusLine
+    from http.client import BadStatusLine, HTTPConnection
 
 try:
     from httplib import HTTPSConnection
@@ -133,51 +152,50 @@ except ImportError:
     from md5 import md5
 
 try:
-    from argparse import ArgumentParser as ArgParser
-    from argparse import SUPPRESS as ARG_SUPPRESS
+    from argparse import SUPPRESS as ARG_SUPPRESS, ArgumentParser as ArgParser
+
     PARSER_TYPE_INT = int
     PARSER_TYPE_STR = str
     PARSER_TYPE_FLOAT = float
 except ImportError:
-    from optparse import OptionParser as ArgParser
-    from optparse import SUPPRESS_HELP as ARG_SUPPRESS
-    PARSER_TYPE_INT = 'int'
-    PARSER_TYPE_STR = 'string'
-    PARSER_TYPE_FLOAT = 'float'
+    from optparse import SUPPRESS_HELP as ARG_SUPPRESS, OptionParser as ArgParser
+
+    PARSER_TYPE_INT = "int"
+    PARSER_TYPE_STR = "string"
+    PARSER_TYPE_FLOAT = "float"
 
 try:
     from cStringIO import StringIO
+
     BytesIO = None
 except ImportError:
     try:
         from StringIO import StringIO
+
        BytesIO = None
     except ImportError:
-        from io import StringIO, BytesIO
+        from io import BytesIO, StringIO
 
 try:
     import __builtin__
 except ImportError:
     import builtins
-    from io import TextIOWrapper, FileIO
+    from io import FileIO, TextIOWrapper
 
     class _Py3Utf8Output(TextIOWrapper):
         """UTF-8 encoded wrapper around stdout for py3, to override
         ASCII stdout
         """
+
         def __init__(self, f, **kwargs):
-            buf = FileIO(f.fileno(), 'w')
-            super(_Py3Utf8Output, self).__init__(
-                buf,
-                encoding='utf8',
-                errors='strict'
-            )
+            buf = FileIO(f.fileno(), "w")
+            super(_Py3Utf8Output, self).__init__(buf, encoding="utf8", errors="strict")
 
         def write(self, s):
             super(_Py3Utf8Output, self).write(s)
             self.flush()
 
-    _py3_print = getattr(builtins, 'print')
+    _py3_print = getattr(builtins, "print")
     try:
         _py3_utf8_stdout = _Py3Utf8Output(sys.stdout)
         _py3_utf8_stderr = _Py3Utf8Output(sys.stderr)
@@ -193,18 +211,19 @@ except ImportError:
 
     def print_(*args, **kwargs):
         """Wrapper function for py3 to print, with a utf-8 encoded stdout"""
-        if kwargs.get('file') == sys.stderr:
-            kwargs['file'] = _py3_utf8_stderr
+        if kwargs.get("file") == sys.stderr:
+            kwargs["file"] = _py3_utf8_stderr
         else:
-            kwargs['file'] = kwargs.get('file', _py3_utf8_stdout)
+            kwargs["file"] = kwargs.get("file", _py3_utf8_stdout)
         _py3_print(*args, **kwargs)
+
 else:
     del __builtin__
 
     def to_utf8(v):
         """Encode value to utf-8 if possible for py2"""
         try:
-            return v.encode('utf8', 'strict')
+            return v.encode("utf8", "strict")
         except AttributeError:
             return v
 
@@ -223,16 +242,15 @@ else:
             if not isinstance(data, basestring):
                 data = str(data)
             # If the file has an encoding, encode unicode with it.
-            encoding = 'utf8'  # Always trust UTF-8 for output
-            if (isinstance(fp, file) and
-                    isinstance(data, unicode) and
-                    encoding is not None):
+            encoding = "utf8"  # Always trust UTF-8 for output
+            if isinstance(fp, file) and isinstance(data, unicode) and encoding is not None:
                 errors = getattr(fp, "errors", None)
                 if errors is None:
                     errors = "strict"
                 data = data.encode(encoding, errors)
             fp.write(data)
             fp.flush()
+
         want_unicode = False
         sep = kwargs.pop("sep", None)
         if sep is not None:
@@ -269,18 +287,17 @@ else:
             write(arg)
         write(end)
 
+
 # Exception "constants" to support Python 2 through Python 3
 try:
     import ssl
+
     try:
         CERT_ERROR = (ssl.CertificateError,)
     except AttributeError:
         CERT_ERROR = tuple()
-    HTTP_ERRORS = (
-        (HTTPError, URLError, socket.error, ssl.SSLError, BadStatusLine) +
-        CERT_ERROR
-    )
+    HTTP_ERRORS = (HTTPError, URLError, socket.error, ssl.SSLError, BadStatusLine) + CERT_ERROR
 except ImportError:
     ssl = None
     HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine)
@@ -373,8 +390,7 @@ class SpeedtestMissingBestServer(SpeedtestException):
     """get_best_server not called or not able to determine best server"""
 
 
-def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
-                      source_address=None):
+def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
     """Connect to *address* and return the socket object.
 
     Convenience function.  Connect to *address* (a 2-tuple ``(host,
@@ -418,9 +434,10 @@ class SpeedtestHTTPConnection(HTTPConnection):
     """Custom HTTPConnection to support source_address across
     Python 2.4 - Python 3
     """
+
     def __init__(self, *args, **kwargs):
-        source_address = kwargs.pop('source_address', None)
-        timeout = kwargs.pop('timeout', 10)
+        source_address = kwargs.pop("source_address", None)
+        timeout = kwargs.pop("timeout", 10)
 
         self._tunnel_host = None
@@ -432,32 +449,26 @@ class SpeedtestHTTPConnection(HTTPConnection):
     def connect(self):
         """Connect to the host and port specified in __init__."""
         try:
-            self.sock = socket.create_connection(
-                (self.host, self.port),
-                self.timeout,
-                self.source_address
-            )
+            self.sock = socket.create_connection((self.host, self.port), self.timeout, self.source_address)
         except (AttributeError, TypeError):
-            self.sock = create_connection(
-                (self.host, self.port),
-                self.timeout,
-                self.source_address
-            )
+            self.sock = create_connection((self.host, self.port), self.timeout, self.source_address)
 
         if self._tunnel_host:
             self._tunnel()
 
 
 if HTTPSConnection:
+
     class SpeedtestHTTPSConnection(HTTPSConnection):
         """Custom HTTPSConnection to support source_address across
         Python 2.4 - Python 3
         """
+
         default_port = 443
 
         def __init__(self, *args, **kwargs):
-            source_address = kwargs.pop('source_address', None)
-            timeout = kwargs.pop('timeout', 10)
+            source_address = kwargs.pop("source_address", None)
+            timeout = kwargs.pop("timeout", 10)
 
             self._tunnel_host = None
@@ -469,17 +480,9 @@ if HTTPSConnection:
         def connect(self):
             "Connect to a host on a given (SSL) port."
             try:
-                self.sock = socket.create_connection(
-                    (self.host, self.port),
-                    self.timeout,
-                    self.source_address
-                )
+                self.sock = socket.create_connection((self.host, self.port), self.timeout, self.source_address)
             except (AttributeError, TypeError):
-                self.sock = create_connection(
-                    (self.host, self.port),
-                    self.timeout,
-                    self.source_address
-                )
+                self.sock = create_connection((self.host, self.port), self.timeout, self.source_address)
 
             if self._tunnel_host:
                 self._tunnel()
@@ -487,11 +490,11 @@ if HTTPSConnection:
             if ssl:
                 try:
                     kwargs = {}
-                    if hasattr(ssl, 'SSLContext'):
+                    if hasattr(ssl, "SSLContext"):
                         if self._tunnel_host:
-                            kwargs['server_hostname'] = self._tunnel_host
+                            kwargs["server_hostname"] = self._tunnel_host
                         else:
-                            kwargs['server_hostname'] = self.host
+                            kwargs["server_hostname"] = self.host
                     self.sock = self._context.wrap_socket(self.sock, **kwargs)
                 except AttributeError:
                     self.sock = ssl.wrap_socket(self.sock)
@@ -504,15 +507,9 @@ if HTTPSConnection:
                 try:
                     self.sock = FakeSocket(self.sock, socket.ssl(self.sock))
                 except AttributeError:
-                    raise SpeedtestException(
-                        'This version of Python does not support HTTPS/SSL '
-                        'functionality'
-                    )
+                    raise SpeedtestException("This version of Python does not support HTTPS/SSL " "functionality")
             else:
-                raise SpeedtestException(
-                    'This version of Python does not support HTTPS/SSL '
-                    'functionality'
-                )
+                raise SpeedtestException("This version of Python does not support HTTPS/SSL " "functionality")
 
 
 def _build_connection(connection, source_address, timeout, context=None):
@@ -522,14 +519,13 @@ def _build_connection(connection, source_address, timeout, context=None):
     Called from ``http(s)_open`` methods of ``SpeedtestHTTPHandler`` or
     ``SpeedtestHTTPSHandler``
     """
+
     def inner(host, **kwargs):
-        kwargs.update({
-            'source_address': source_address,
-            'timeout': timeout
-        })
+        kwargs.update({"source_address": source_address, "timeout": timeout})
         if context:
-            kwargs['context'] = context
+            kwargs["context"] = context
         return connection(host, **kwargs)
+
     return inner
 
 
@@ -537,20 +533,14 @@ class SpeedtestHTTPHandler(AbstractHTTPHandler):
     """Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the
     args we need for ``source_address`` and ``timeout``
     """
+
     def __init__(self, debuglevel=0, source_address=None, timeout=10):
         AbstractHTTPHandler.__init__(self, debuglevel)
         self.source_address = source_address
         self.timeout = timeout
 
     def http_open(self, req):
-        return self.do_open(
-            _build_connection(
-                SpeedtestHTTPConnection,
-                self.source_address,
-                self.timeout
-            ),
-            req
-        )
+        return self.do_open(_build_connection(SpeedtestHTTPConnection, self.source_address, self.timeout), req)
 
     http_request = AbstractHTTPHandler.do_request_
 
@@ -559,8 +549,8 @@ class SpeedtestHTTPSHandler(AbstractHTTPHandler):
     """Custom ``HTTPSHandler`` that can build a ``HTTPSConnection`` with the
     args we need for ``source_address`` and ``timeout``
     """
-    def __init__(self, debuglevel=0, context=None, source_address=None,
-                 timeout=10):
+
+    def __init__(self, debuglevel=0, context=None, source_address=None, timeout=10):
         AbstractHTTPHandler.__init__(self, debuglevel)
         self._context = context
         self.source_address = source_address
@@ -574,7 +564,7 @@ class SpeedtestHTTPSHandler(AbstractHTTPHandler):
                 self.timeout,
                 context=self._context,
             ),
-            req
+            req,
         )
 
     https_request = AbstractHTTPHandler.do_request_
@@ -587,28 +577,25 @@ def build_opener(source_address=None, timeout=10):
     `User-Agent`
     """
 
-    printer('Timeout set to %d' % timeout, debug=True)
+    printer("Timeout set to %d" % timeout, debug=True)
 
     if source_address:
         source_address_tuple = (source_address, 0)
-        printer('Binding to source address: %r' % (source_address_tuple,),
-                debug=True)
+        printer("Binding to source address: %r" % (source_address_tuple,), debug=True)
     else:
         source_address_tuple = None
 
     handlers = [
         ProxyHandler(),
-        SpeedtestHTTPHandler(source_address=source_address_tuple,
-                             timeout=timeout),
-        SpeedtestHTTPSHandler(source_address=source_address_tuple,
-                              timeout=timeout),
+        SpeedtestHTTPHandler(source_address=source_address_tuple, timeout=timeout),
+        SpeedtestHTTPSHandler(source_address=source_address_tuple, timeout=timeout),
        HTTPDefaultErrorHandler(),
         HTTPRedirectHandler(),
-        HTTPErrorProcessor()
+        HTTPErrorProcessor(),
     ]
 
     opener = OpenerDirector()
-    opener.addheaders = [('User-agent', build_user_agent())]
+    opener.addheaders = [("User-agent", build_user_agent())]
 
     for handler in handlers:
         opener.add_handler(handler)
@@ -623,12 +610,12 @@ class GzipDecodedResponse(GZIP_BASE):
     Largely copied from ``xmlrpclib``/``xmlrpc.client`` and modified
     to work for py2.4-py3
     """
+
     def __init__(self, response):
         # response doesn't support tell() and read(), required by
         # GzipFile
         if not gzip:
-            raise SpeedtestHTTPError('HTTP response body is gzip encoded, '
-                                     'but gzip support is not available')
+            raise SpeedtestHTTPError("HTTP response body is gzip encoded, " "but gzip support is not available")
         IO = BytesIO or StringIO
         self.io = IO()
         while 1:
@@ -637,7 +624,7 @@ class GzipDecodedResponse(GZIP_BASE):
                 break
             self.io.write(chunk)
         self.io.seek(0)
-        gzip.GzipFile.__init__(self, mode='rb', fileobj=self.io)
+        gzip.GzipFile.__init__(self, mode="rb", fileobj=self.io)
 
     def close(self):
         try:
@@ -662,10 +649,9 @@ def distance(origin, destination):
     dlat = math.radians(lat2 - lat1)
     dlon = math.radians(lon2 - lon1)
-    a = (math.sin(dlat / 2) * math.sin(dlat / 2) +
-         math.cos(math.radians(lat1)) *
-         math.cos(math.radians(lat2)) * math.sin(dlon / 2) *
-         math.sin(dlon / 2))
+    a = math.sin(dlat / 2) * math.sin(dlat / 2) + math.cos(math.radians(lat1)) * math.cos(
+        math.radians(lat2)
+    ) * math.sin(dlon / 2) * math.sin(dlon / 2)
     c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
     d = radius * c
 
     return d
@@ -676,19 +662,18 @@ def distance(origin, destination):
 def build_user_agent():
     """Build a Mozilla/5.0 compatible User-Agent string"""
     ua_tuple = (
-        'Mozilla/5.0',
-        '(%s; U; %s; en-us)' % (platform.platform(),
-                                platform.architecture()[0]),
-        'Python/%s' % platform.python_version(),
-        '(KHTML, like Gecko)',
-        'speedtest-cli/%s' % __version__
+        "Mozilla/5.0",
+        "(%s; U; %s; en-us)" % (platform.platform(), platform.architecture()[0]),
+        "Python/%s" % platform.python_version(),
+        "(KHTML, like Gecko)",
+        "speedtest-cli/%s" % __version__,
     )
-    user_agent = ' '.join(ua_tuple)
-    printer('User-Agent: %s' % user_agent, debug=True)
+    user_agent = " ".join(ua_tuple)
+    printer("User-Agent: %s" % user_agent, debug=True)
     return user_agent
 
 
-def build_request(url, data=None, headers=None, bump='0', secure=False):
+def build_request(url, data=None, headers=None, bump="0", secure=False):
     """Build a urllib2 request object
 
     This function automatically adds a User-Agent header to all requests
@@ -698,28 +683,27 @@ def build_request(url, data=None, headers=None, bump='0', secure=False):
     if not headers:
         headers = {}
 
-    if url[0] == ':':
-        scheme = ('http', 'https')[bool(secure)]
-        schemed_url = '%s%s' % (scheme, url)
+    if url[0] == ":":
+        scheme = ("http", "https")[bool(secure)]
+        schemed_url = "%s%s" % (scheme, url)
     else:
         schemed_url = url
 
-    if '?' in url:
-        delim = '&'
+    if "?" in url:
+        delim = "&"
     else:
-        delim = '?'
+        delim = "?"
 
     # WHO YOU GONNA CALL? CACHE BUSTERS!
-    final_url = '%s%sx=%s.%s' % (schemed_url, delim,
-                                 int(timeit.time.time() * 1000),
-                                 bump)
+    final_url = "%s%sx=%s.%s" % (schemed_url, delim, int(timeit.time.time() * 1000), bump)
 
-    headers.update({
-        'Cache-Control': 'no-cache',
-    })
+    headers.update(
+        {
+            "Cache-Control": "no-cache",
+        }
+    )
 
-    printer('%s %s' % (('GET', 'POST')[bool(data)], final_url),
-            debug=True)
+    printer("%s %s" % (("GET", "POST")[bool(data)], final_url), debug=True)
 
     return Request(final_url, data=data, headers=headers)
 
@@ -738,7 +722,7 @@ def catch_request(request, opener=None):
     try:
         uh = _open(request)
         if request.get_full_url() != uh.geturl():
-            printer('Redirected to %s' % uh.geturl(), debug=True)
+            printer("Redirected to %s" % uh.geturl(), debug=True)
         return uh, False
     except HTTP_ERRORS:
         e = get_exception()
@@ -756,7 +740,7 @@ def get_response_stream(response):
     except AttributeError:
         getheader = response.getheader
 
-    if getheader('content-encoding') == 'gzip':
+    if getheader("content-encoding") == "gzip":
         return GzipDecodedResponse(response)
 
     return response
@@ -777,14 +761,16 @@ def print_dots(shutdown_event):
     """Built in callback function used by Thread classes for printing
     status
     """
+
     def inner(current, total, start=False, end=False):
         if event_is_set(shutdown_event):
             return
 
-        sys.stdout.write('.')
+        sys.stdout.write(".")
         if current + 1 == total and end is True:
-            sys.stdout.write('\n')
+            sys.stdout.write("\n")
         sys.stdout.flush()
+
     return inner
 
 
@@ -795,8 +781,7 @@ def do_nothing(*args, **kwargs):
 class HTTPDownloader(threading.Thread):
     """Thread class for retrieving a URL"""
 
-    def __init__(self, i, request, start, timeout, opener=None,
-                 shutdown_event=None):
+    def __init__(self, i, request, start, timeout, opener=None, shutdown_event=None):
         threading.Thread.__init__(self)
         self.request = request
         self.result = [0]
@@ -817,9 +802,9 @@ class HTTPDownloader(threading.Thread):
         try:
             if (timeit.default_timer() - self.starttime) <= self.timeout:
                 f = self._opener(self.request)
-                while (not event_is_set(self._shutdown_event) and
-                        (timeit.default_timer() - self.starttime) <=
-                        self.timeout):
+                while (
+                    not event_is_set(self._shutdown_event) and (timeit.default_timer() - self.starttime) <= self.timeout
+                ):
                     self.result.append(len(f.read(10240)))
                     if self.result[-1] == 0:
                         break
@@ -850,20 +835,13 @@ class HTTPUploaderData(object):
         self.total = [0]
 
     def pre_allocate(self):
-        chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+        chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
         multiplier = int(round(int(self.length) / 36.0))
         IO = BytesIO or StringIO
         try:
-            self._data = IO(
-                ('content1=%s' %
-                 (chars * multiplier)[0:int(self.length) - 9]
-                 ).encode()
-            )
+            self._data = IO(("content1=%s" % (chars * multiplier)[0 : int(self.length) - 9]).encode())
         except MemoryError:
-            raise SpeedtestCLIError(
-                'Insufficient memory to pre-allocate upload data. Please '
-                'use --no-pre-allocate'
-            )
+            raise SpeedtestCLIError("Insufficient memory to pre-allocate upload data. Please " "use --no-pre-allocate")
 
     @property
     def data(self):
@@ -872,8 +850,7 @@ class HTTPUploaderData(object):
         return self._data
 
     def read(self, n=10240):
-        if ((timeit.default_timer() - self.start) <= self.timeout and
-                not event_is_set(self._shutdown_event)):
+        if (timeit.default_timer() - self.start) <= self.timeout and not event_is_set(self._shutdown_event):
             chunk = self.data.read(n)
             self.total.append(len(chunk))
             return chunk
@@ -887,8 +864,7 @@ class HTTPUploader(threading.Thread):
     """Thread class for putting a URL"""
 
-    def __init__(self, i, request, start, size, timeout, opener=None,
-                 shutdown_event=None):
+    def __init__(self, i, request, start, size, timeout, opener=None, shutdown_event=None):
         threading.Thread.__init__(self)
         self.request = request
         self.request.data.start = self.starttime = start
@@ -910,16 +886,14 @@ class HTTPUploader(threading.Thread):
     def run(self):
         request = self.request
         try:
-            if ((timeit.default_timer() - self.starttime) <= self.timeout and
-                    not event_is_set(self._shutdown_event)):
+            if (timeit.default_timer() - self.starttime) <= self.timeout and not event_is_set(self._shutdown_event):
                 try:
                     f = self._opener(request)
                 except TypeError:
                     # PY24 expects a string or buffer
                     # This also causes issues with Ctrl-C, but we will concede
                     # for the moment that Ctrl-C on PY24 isn't immediate
-                    request = build_request(self.request.get_full_url(),
-                                            data=request.data.read(self.size))
+                    request = build_request(self.request.get_full_url(), data=request.data.read(self.size))
                     f = self._opener(request)
                 f.read(11)
                 f.close()
@@ -945,8 +919,7 @@ class SpeedtestResults(object):
     to get a share results image link.
     """
 
-    def __init__(self, download=0, upload=0, ping=0, server=None, client=None,
-                 opener=None, secure=False):
+    def __init__(self, download=0, upload=0, ping=0, server=None, client=None, opener=None, secure=False):
         self.download = download
         self.upload = upload
         self.ping = ping
@@ -957,7 +930,7 @@ class SpeedtestResults(object):
         self.client = client or {}
 
         self._share = None
-        self.timestamp = '%sZ' % datetime.datetime.utcnow().isoformat()
+        self.timestamp = "%sZ" % datetime.datetime.utcnow().isoformat()
         self.bytes_received = 0
         self.bytes_sent = 0
@@ -987,29 +960,27 @@ class SpeedtestResults(object):
         # We use a list instead of a dict because the API expects parameters
         # in a certain order
         api_data = [
-            'recommendedserverid=%s' % self.server['id'],
-            'ping=%s' % ping,
-            'screenresolution=',
-            'promo=',
-            'download=%s' % download,
-            'screendpi=',
-            'upload=%s' % upload,
-            'testmethod=http',
-            'hash=%s' % md5(('%s-%s-%s-%s' %
-                             (ping, upload, download, '297aae72'))
-                            .encode()).hexdigest(),
-            'touchscreen=none',
-            'startmode=pingselect',
-            'accuracy=1',
-            'bytesreceived=%s' % self.bytes_received,
-            'bytessent=%s' % self.bytes_sent,
-            'serverid=%s' % self.server['id'],
+            "recommendedserverid=%s" % self.server["id"],
+            "ping=%s" % ping,
+            "screenresolution=",
+            "promo=",
+            "download=%s" % download,
+            "screendpi=",
+            "upload=%s" % upload,
+            "testmethod=http",
+            "hash=%s" % md5(("%s-%s-%s-%s" % (ping, upload, download, "297aae72")).encode()).hexdigest(),
+            "touchscreen=none",
+            "startmode=pingselect",
+            "accuracy=1",
+            "bytesreceived=%s" % self.bytes_received,
+            "bytessent=%s" % self.bytes_sent,
+            "serverid=%s" % self.server["id"],
         ]
 
-        headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'}
-        request = build_request('://www.speedtest.net/api/api.php',
-                                data='&'.join(api_data).encode(),
-                                headers=headers, secure=self._secure)
+        headers = {"Referer": "http://c.speedtest.net/flash/speedtest.swf"}
+        request = build_request(
+            "://www.speedtest.net/api/api.php", data="&".join(api_data).encode(), headers=headers, secure=self._secure
+        )
         f, e = catch_request(request, opener=self._opener)
         if e:
             raise ShareResultsConnectFailure(e)
@@ -1019,16 +990,14 @@ class SpeedtestResults(object):
         f.close()
 
         if int(code) != 200:
-            raise ShareResultsSubmitFailure('Could not submit results to '
-                                            'speedtest.net')
+            raise ShareResultsSubmitFailure("Could not submit results to " "speedtest.net")
 
         qsargs = parse_qs(response.decode())
-        resultid = qsargs.get('resultid')
+        resultid = qsargs.get("resultid")
         if not resultid or len(resultid) != 1:
-            raise ShareResultsSubmitFailure('Could not submit results to '
-                                            'speedtest.net')
+            raise ShareResultsSubmitFailure("Could not submit results to " "speedtest.net")
 
-        self._share = 'http://www.speedtest.net/result/%s.png' % resultid[0]
+        self._share = "http://www.speedtest.net/result/%s.png" % resultid[0]
 
         return self._share
 
@@ -1036,38 +1005,56 @@ class SpeedtestResults(object):
         """Return dictionary of result data"""
 
         return {
-            'download': self.download,
-            'upload': self.upload,
-            'ping': self.ping,
-            'server': self.server,
-            'timestamp': self.timestamp,
-            'bytes_sent': self.bytes_sent,
-            'bytes_received': self.bytes_received,
-            'share': self._share,
-            'client': self.client,
+            "download": self.download,
+            "upload": self.upload,
+            "ping": self.ping,
+            "server": self.server,
+            "timestamp": self.timestamp,
+            "bytes_sent": self.bytes_sent,
+            "bytes_received": self.bytes_received,
+            "share": self._share,
+            "client": self.client,
         }
 
     @staticmethod
-    def csv_header(delimiter=','):
+    def csv_header(delimiter=","):
         """Return CSV Headers"""
 
-        row = ['Server ID', 'Sponsor', 'Server Name', 'Timestamp', 'Distance',
-               'Ping', 'Download', 'Upload', 'Share', 'IP Address']
+        row = [
+            "Server ID",
+            "Sponsor",
+            "Server Name",
+            "Timestamp",
+            "Distance",
+            "Ping",
+            "Download",
+            "Upload",
+            "Share",
+            "IP Address",
+        ]
         out = StringIO()
-        writer = csv.writer(out, delimiter=delimiter, lineterminator='')
+        writer = csv.writer(out, delimiter=delimiter, lineterminator="")
         writer.writerow([to_utf8(v) for v in row])
         return out.getvalue()
 
-    def csv(self, delimiter=','):
+    def csv(self, delimiter=","):
         """Return data in CSV format"""
 
         data = self.dict()
         out = StringIO()
-        writer = csv.writer(out, delimiter=delimiter, lineterminator='')
-        row = [data['server']['id'], data['server']['sponsor'],
-               data['server']['name'], data['timestamp'],
-               data['server']['d'], data['ping'], data['download'],
-               data['upload'], self._share or '', self.client['ip']]
+        writer = csv.writer(out, delimiter=delimiter, lineterminator="")
+        row = [
+            data["server"]["id"],
+            data["server"]["sponsor"],
+            data["server"]["name"],
+            data["timestamp"],
+            data["server"]["d"],
+            data["ping"],
+            data["download"],
+            data["upload"],
+            self._share or "",
+            self.client["ip"],
+        ]
         writer.writerow([to_utf8(v) for v in row])
         return out.getvalue()
 
@@ -1076,18 +1063,14 @@ class SpeedtestResults(object):
         kwargs = {}
         if pretty:
-            kwargs.update({
-                'indent': 4,
-                'sort_keys': True
-            })
+            kwargs.update({"indent": 4, "sort_keys": True})
         return json.dumps(self.dict(), **kwargs)
 
 
 class Speedtest(object):
     """Class for performing standard speedtest.net testing operations"""
 
-    def __init__(self, config=None, source_address=None, timeout=10,
-                 secure=False, shutdown_event=None):
+    def __init__(self, config=None, source_address=None, timeout=10, secure=False, shutdown_event=None):
         self.config = {}
 
         self._source_address = source_address
@@ -1110,7 +1093,7 @@ class Speedtest(object):
         self._best = {}
 
         self.results = SpeedtestResults(
-            client=self.config['client'],
+            client=self.config["client"],
             opener=self._opener,
             secure=secure,
         )
@@ -1128,9 +1111,8 @@ class Speedtest(object):
         headers = {}
         if gzip:
-            headers['Accept-Encoding'] = 'gzip'
-        request = build_request('://www.speedtest.net/speedtest-config.php',
-                                headers=headers, secure=self._secure)
+            headers["Accept-Encoding"] = "gzip"
+        request = build_request("://www.speedtest.net/speedtest-config.php", headers=headers, secure=self._secure)
         uh, e = catch_request(request, opener=self._opener)
         if e:
             raise ConfigRetrievalError(e)
@@ -1151,89 +1133,69 @@ class Speedtest(object):
         if int(uh.code) != 200:
             return None
 
-        configxml = ''.encode().join(configxml_list)
+        configxml = "".encode().join(configxml_list)
 
-        printer('Config XML:\n%s' % configxml, debug=True)
+        printer("Config XML:\n%s" % configxml, debug=True)
 
         try:
             try:
                 root = ET.fromstring(configxml)
             except ET.ParseError:
                 e = get_exception()
-                raise SpeedtestConfigError(
-                    'Malformed speedtest.net configuration: %s' % e
-                )
-            server_config = root.find('server-config').attrib
-            download = root.find('download').attrib
-            upload = root.find('upload').attrib
+                raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
+            server_config = root.find("server-config").attrib
+            download = root.find("download").attrib
+            upload = root.find("upload").attrib
             # times = root.find('times').attrib
-            client = root.find('client').attrib
+            client = root.find("client").attrib
         except AttributeError:
             try:
                 root = DOM.parseString(configxml)
             except ExpatError:
                 e = get_exception()
-                raise SpeedtestConfigError(
-                    'Malformed speedtest.net configuration: %s' % e
-                )
-            server_config = get_attributes_by_tag_name(root, 'server-config')
-            download = get_attributes_by_tag_name(root, 'download')
-            upload = get_attributes_by_tag_name(root, 'upload')
+                raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
+            server_config = get_attributes_by_tag_name(root, "server-config")
+            download = get_attributes_by_tag_name(root, "download")
+            upload = get_attributes_by_tag_name(root, "upload")
             # times = get_attributes_by_tag_name(root, 'times')
-            client = get_attributes_by_tag_name(root, 'client')
+            client = get_attributes_by_tag_name(root, "client")
 
-        ignore_servers = [
-            int(i) for i in server_config['ignoreids'].split(',') if i
-        ]
+        ignore_servers = [int(i) for i in server_config["ignoreids"].split(",") if i]
 
-        ratio = int(upload['ratio'])
-        upload_max = int(upload['maxchunkcount'])
+        ratio = int(upload["ratio"])
+        upload_max = int(upload["maxchunkcount"])
         up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032]
-        sizes = {
-            'upload': up_sizes[ratio - 1:],
-            'download': [350, 500, 750, 1000, 1500, 2000, 2500,
-                         3000, 3500, 4000]
-        }
+        sizes = {"upload": up_sizes[ratio - 1 :], "download": [350, 500, 750, 1000, 1500, 2000, 2500, 3000, 3500, 4000]}
 
-        size_count = len(sizes['upload'])
+        size_count = len(sizes["upload"])
 
         upload_count = int(math.ceil(upload_max / size_count))
 
-        counts = {
-            'upload': upload_count,
-            'download': int(download['threadsperurl'])
-        }
+        counts = {"upload": upload_count, "download": int(download["threadsperurl"])}
 
-        threads = {
-            'upload': int(upload['threads']),
-            'download': int(server_config['threadcount']) * 2
-        }
+        threads = {"upload": int(upload["threads"]), "download": int(server_config["threadcount"]) * 2}
 
-        length = {
-            'upload': int(upload['testlength']),
-            'download': int(download['testlength'])
-        }
+        length = {"upload": int(upload["testlength"]), "download": int(download["testlength"])}
 
-        self.config.update({
-            'client': client,
-            'ignore_servers': ignore_servers,
-            'sizes': sizes,
-            'counts': counts,
-            'threads': threads,
-            'length': length,
-            'upload_max': upload_count * size_count
-        })
+        self.config.update(
+            {
+                "client": client,
+                "ignore_servers": ignore_servers,
+                "sizes": sizes,
+                "counts": counts,
+                "threads": threads,
+                "length": length,
+                "upload_max": upload_count * size_count,
+            }
+        )
 
         try:
-            self.lat_lon = (float(client['lat']), float(client['lon']))
+            self.lat_lon = (float(client["lat"]), float(client["lon"]))
         except ValueError:
-            raise SpeedtestConfigError(
-                'Unknown location: lat=%r lon=%r' %
-                (client.get('lat'), client.get('lon'))
-            )
+            raise SpeedtestConfigError("Unknown location: lat=%r lon=%r" % (client.get("lat"), client.get("lon")))
 
-        printer('Config:\n%r' % self.config, debug=True)
+        printer("Config:\n%r" % self.config, debug=True)
 
         return self.config
 
@@ -1254,33 +1216,28 @@ class Speedtest(object):
                 try:
                     server_list[i] = int(s)
                 except ValueError:
-                    raise InvalidServerIDType(
-                        '%s is an invalid server type, must be int' % s
-                    )
+                    raise InvalidServerIDType("%s is an invalid server type, must be int" % s)
 
         urls = [
-            '://www.speedtest.net/speedtest-servers-static.php',
-            'http://c.speedtest.net/speedtest-servers-static.php',
-            '://www.speedtest.net/speedtest-servers.php',
-            'http://c.speedtest.net/speedtest-servers.php',
+            "://www.speedtest.net/speedtest-servers-static.php",
+            "http://c.speedtest.net/speedtest-servers-static.php",
+            "://www.speedtest.net/speedtest-servers.php",
+            "http://c.speedtest.net/speedtest-servers.php",
         ]
 
         headers = {}
         if gzip:
-            headers['Accept-Encoding'] = 'gzip'
+            headers["Accept-Encoding"] = "gzip"
 
         errors = []
         for url in urls:
             try:
                 request = build_request(
-                    '%s?threads=%s' % (url,
-                                       self.config['threads']['download']),
-                    headers=headers,
-                    secure=self._secure
+                    "%s?threads=%s" % (url, self.config["threads"]["download"]), headers=headers, secure=self._secure
                 )
                 uh, e = catch_request(request, opener=self._opener)
                 if e:
-                    errors.append('%s' % e)
+                    errors.append("%s" % e)
                     raise ServersRetrievalError()
 
                 stream = get_response_stream(uh)
@@ -1300,9 +1257,9 @@ class Speedtest(object):
                 if int(uh.code) != 200:
                     raise ServersRetrievalError()
 
-                serversxml = ''.encode().join(serversxml_list)
+                serversxml = "".encode().join(serversxml_list)
 
-                printer('Servers XML:\n%s' % serversxml, debug=True)
+                printer("Servers XML:\n%s" % serversxml, debug=True)
 
                 try:
                     try:
@@ -1310,19 +1267,15 @@ class Speedtest(object):
                         root = ET.fromstring(serversxml)
                     except ET.ParseError:
                         e = get_exception()
-                        raise SpeedtestServersError(
-                            'Malformed speedtest.net server list: %s' % e
-                        )
-                    elements = etree_iter(root, 'server')
+                        raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
+                    elements = etree_iter(root, "server")
                 except AttributeError:
                     try:
                         root = DOM.parseString(serversxml)
                     except ExpatError:
                         e = get_exception()
-                        raise SpeedtestServersError(
-                            'Malformed speedtest.net server list: %s' % e
-                        )
-                    elements = root.getElementsByTagName('server')
+                        raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
+                    elements = root.getElementsByTagName("server")
             except (SyntaxError, xml.parsers.expat.ExpatError):
                 raise ServersRetrievalError()
 
@@ -1332,21 +1285,18 @@ class Speedtest(object):
                 except AttributeError:
                     attrib = dict(list(server.attributes.items()))
 
-                if servers and int(attrib.get('id')) not in servers:
+                if servers and int(attrib.get("id")) not in servers:
                     continue
 
-                if (int(attrib.get('id')) in self.config['ignore_servers']
-                        or int(attrib.get('id')) in exclude):
+                if int(attrib.get("id")) in self.config["ignore_servers"] or int(attrib.get("id")) in exclude:
                     continue
 
                 try:
-                    d = distance(self.lat_lon,
-                                 (float(attrib.get('lat')),
-                                  float(attrib.get('lon'))))
+                    d = distance(self.lat_lon, (float(attrib.get("lat")), float(attrib.get("lon"))))
                 except Exception:
                     continue
 
-                attrib['d'] = d
+                attrib["d"] = d
 
                 try:
                     self.servers[d].append(attrib)
@@ -1379,41 +1329,36 @@ class Speedtest(object):
         request = build_request(url)
         uh, e = catch_request(request, opener=self._opener)
         if e:
-            raise SpeedtestMiniConnectFailure('Failed to connect to %s' %
-                                              server)
+            raise SpeedtestMiniConnectFailure("Failed to connect to %s" % server)
         else:
             text = uh.read()
             uh.close()
 
-        extension = re.findall('upload_?[Ee]xtension: "([^"]+)"',
-                               text.decode())
+        extension = re.findall('upload_?[Ee]xtension: "([^"]+)"', text.decode())
         if not extension:
-            for ext in ['php', 'asp', 'aspx', 'jsp']:
+            for ext in ["php", "asp", "aspx", "jsp"]:
                 try:
-                    f = self._opener.open(
-                        '%s/speedtest/upload.%s' % (url, ext)
-                    )
+                    f = self._opener.open("%s/speedtest/upload.%s" % (url, ext))
                 except Exception:
                     pass
                 else:
                     data = f.read().strip().decode()
-                    if (f.code == 200 and
-                            len(data.splitlines()) == 1 and
-                            re.match('size=[0-9]', data)):
+                    if f.code == 200 and len(data.splitlines()) == 1 and re.match("size=[0-9]", data):
                         extension = [ext]
                         break
         if not urlparts or not extension:
-            raise InvalidSpeedtestMiniServer('Invalid Speedtest Mini Server: '
-                                             '%s' % server)
-
-        self.servers = [{
-            'sponsor': 'Speedtest Mini',
-            'name': urlparts[1],
-            'd': 0,
-            'url': '%s/speedtest/upload.%s' % (url.rstrip('/'), extension[0]),
-            'latency': 0,
-            'id': 0
-        }]
+            raise InvalidSpeedtestMiniServer("Invalid Speedtest Mini Server: " "%s" % server)
+
+        self.servers = [
+            {
+                "sponsor": "Speedtest Mini",
+                "name": urlparts[1],
+                "d": 0,
+                "url": "%s/speedtest/upload.%s" % (url.rstrip("/"), extension[0]),
+                "latency": 0,
+                "id": 0,
+            }
+        ]
 
         return self.servers
 
@@ -1434,7 +1379,7 @@ class Speedtest(object):
                 continue
             break
 
-        printer('Closest Servers:\n%r' % self.closest, debug=True)
+        printer("Closest Servers:\n%r" % self.closest, debug=True)
         return self.closest
 
     def get_best_server(self, servers=None):
@@ -1457,39 +1402,32 @@ class Speedtest(object):
         results = {}
         for server in servers:
             cum = []
-            url = os.path.dirname(server['url'])
+            url = os.path.dirname(server["url"])
             stamp = int(timeit.time.time() * 1000)
-            latency_url = '%s/latency.txt?x=%s' % (url, stamp)
+            latency_url = "%s/latency.txt?x=%s" % (url, stamp)
             for i in range(0, 3):
-                this_latency_url = '%s.%s' % (latency_url, i)
-                printer('%s %s' % ('GET', this_latency_url),
-                        debug=True)
+                this_latency_url = "%s.%s" % (latency_url, i)
+                printer("%s %s" % ("GET", this_latency_url), debug=True)
                 urlparts = urlparse(latency_url)
                 try:
-                    if urlparts[0] == 'https':
-                        h = SpeedtestHTTPSConnection(
-                            urlparts[1],
-                            source_address=source_address_tuple
-                        )
+                    if urlparts[0] == "https":
+                        h = SpeedtestHTTPSConnection(urlparts[1], source_address=source_address_tuple)
                     else:
-                        h = SpeedtestHTTPConnection(
-                            urlparts[1],
-                            source_address=source_address_tuple
-                        )
-                    headers = {'User-Agent': user_agent}
-                    path = '%s?%s' % (urlparts[2], urlparts[4])
+                        h = SpeedtestHTTPConnection(urlparts[1], source_address=source_address_tuple)
+                    headers = {"User-Agent": user_agent}
+                    path = "%s?%s" % (urlparts[2], urlparts[4])
                     start = timeit.default_timer()
                     h.request("GET", path, headers=headers)
                     r = h.getresponse()
-                    total = (timeit.default_timer() - start)
+                    total = timeit.default_timer() - start
                 except HTTP_ERRORS:
                     e = get_exception()
-                    printer('ERROR: %r' % e, debug=True)
+                    printer("ERROR: %r" % e, debug=True)
                     cum.append(3600)
                     continue
 
                 text = r.read(9)
-                if int(r.status) == 200 and text == 'test=test'.encode():
+                if int(r.status) == 200 and text == "test=test".encode():
                     cum.append(total)
                 else:
                     cum.append(3600)
@@ -1501,16 +1439,15 @@ class Speedtest(object):
         try:
             fastest = sorted(results.keys())[0]
         except IndexError:
-            raise SpeedtestBestServerFailure('Unable to connect to servers to '
-                                             'test latency.')
+            raise SpeedtestBestServerFailure("Unable to connect to servers to " "test latency.")
         best = results[fastest]
-        best['latency'] = fastest
+        best["latency"] = fastest
 
         self.results.ping = fastest
         self.results.server = best
 
         self._best.update(best)
-        printer('Best Server:\n%r' % best, debug=True)
+        printer("Best Server:\n%r" % best, debug=True)
         return best
 
     def download(self, callback=do_nothing, threads=None):
@@ -1521,20 +1458,17 @@ class Speedtest(object):
         """
 
         urls = []
-        for size in self.config['sizes']['download']:
-            for _ in range(0, self.config['counts']['download']):
-                urls.append('%s/random%sx%s.jpg' %
-                            (os.path.dirname(self.best['url']), size, size))
+        for size in self.config["sizes"]["download"]:
+            for _ in range(0, self.config["counts"]["download"]):
+                urls.append("%s/random%sx%s.jpg" % (os.path.dirname(self.best["url"]), size, size))
 
         request_count = len(urls)
         requests = []
         for i, url in enumerate(urls):
-            requests.append(
-                build_request(url, bump=i, secure=self._secure)
-            )
+            requests.append(build_request(url, bump=i, secure=self._secure))
 
-        max_threads = threads or self.config['threads']['download']
-        in_flight = {'threads': 0}
+        max_threads = threads or self.config["threads"]["download"]
+        in_flight = {"threads": 0}
 
         def producer(q, requests, request_count):
             for i, request in enumerate(requests):
@@ -1542,15 +1476,15 @@ class Speedtest(object):
                 thread = HTTPDownloader(
                     i,
                     request,
                     start,
-                    self.config['length']['download'],
+                    self.config["length"]["download"],
                     opener=self._opener,
-                    shutdown_event=self._shutdown_event
+                    shutdown_event=self._shutdown_event,
                 )
-                while in_flight['threads'] >= max_threads:
+                while in_flight["threads"] >= max_threads:
                     timeit.time.sleep(0.001)
                 thread.start()
                 q.put(thread, True)
-                in_flight['threads'] += 1
+                in_flight["threads"] += 1
                 callback(i, request_count, start=True)
 
         finished = []
@@ -1561,15 +1495,13 @@ class Speedtest(object):
                 thread = q.get(True)
                 while _is_alive(thread):
                     thread.join(timeout=0.001)
-                in_flight['threads'] -= 1
+                in_flight["threads"] -= 1
                 finished.append(sum(thread.result))
                 callback(thread.i, request_count, end=True)
 
         q = Queue(max_threads)
-        prod_thread = threading.Thread(target=producer,
-                                       args=(q, requests, request_count))
-        cons_thread = threading.Thread(target=consumer,
-                                       args=(q, request_count))
+        prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
+        cons_thread = threading.Thread(target=consumer, args=(q, request_count))
         start = timeit.default_timer()
         prod_thread.start()
         cons_thread.start()
@@ -1581,11 +1513,9 @@ class Speedtest(object):
         stop = timeit.default_timer()
 
         self.results.bytes_received = sum(finished)
-        self.results.download = (
-            (self.results.bytes_received / (stop - start)) * 8.0
-        )
+        self.results.download = (self.results.bytes_received / (stop - start)) * 8.0
         if self.results.download > 100000:
-            self.config['threads']['upload'] = 8
+            self.config["threads"]["upload"] = 8
         return self.results.download
 
     def upload(self, callback=do_nothing, pre_allocate=True, threads=None):
@@ -1597,37 +1527,26 @@ class Speedtest(object):
         """
 
         sizes = []
 
-        for size in self.config['sizes']['upload']:
-            for _ in range(0, self.config['counts']['upload']):
+        for size in self.config["sizes"]["upload"]:
+            for _ in range(0, self.config["counts"]["upload"]):
                 sizes.append(size)
 
         # request_count = len(sizes)
-        request_count = self.config['upload_max']
+        request_count = self.config["upload_max"]
 
         requests = []
         for i, size in enumerate(sizes):
             # We set ``0`` for ``start`` and handle setting the actual
             # ``start`` in ``HTTPUploader`` to get better measurements
-            data = HTTPUploaderData(
-                size,
-                0,
-                self.config['length']['upload'],
-                shutdown_event=self._shutdown_event
-            )
+            data = HTTPUploaderData(size, 0, self.config["length"]["upload"], shutdown_event=self._shutdown_event)
             if pre_allocate:
                 data.pre_allocate()
-            headers = {'Content-length': size}
-            requests.append(
-                (
-                    build_request(self.best['url'], data, secure=self._secure,
-                                  headers=headers),
-                    size
-                )
-            )
+            headers = {"Content-length": size}
+            requests.append((build_request(self.best["url"], data, secure=self._secure, headers=headers), size))
 
-        max_threads = threads or self.config['threads']['upload']
-        in_flight = {'threads': 0}
+        max_threads = threads or self.config["threads"]["upload"]
+        in_flight = {"threads": 0}
 
         def producer(q, requests, request_count):
             for i, request in enumerate(requests[:request_count]):
@@ -1636,15 +1555,15 @@ class Speedtest(object):
                 thread = HTTPUploader(
                     i,
                     request[0],
                     start,
                     request[1],
-                    self.config['length']['upload'],
+                    self.config["length"]["upload"],
                     opener=self._opener,
-                    shutdown_event=self._shutdown_event
+                    shutdown_event=self._shutdown_event,
                 )
-                while in_flight['threads'] >= max_threads:
+                while in_flight["threads"] >= max_threads:
                     timeit.time.sleep(0.001)
                 thread.start()
                 q.put(thread, True)
-                in_flight['threads'] += 1
+                in_flight["threads"] += 1
                 callback(i, request_count, start=True)
 
         finished = []
@@ -1655,15 +1574,13 @@ class Speedtest(object):
                 thread = q.get(True)
                 while _is_alive(thread):
                     thread.join(timeout=0.001)
-                in_flight['threads'] -= 1
+                in_flight["threads"] -= 1
                 finished.append(thread.result)
                 callback(thread.i, request_count, end=True)
 
-        q = Queue(threads or self.config['threads']['upload'])
-        prod_thread = threading.Thread(target=producer,
-                                       args=(q, requests, request_count))
-        cons_thread = threading.Thread(target=consumer,
-                                       args=(q, request_count))
+        q = Queue(threads or self.config["threads"]["upload"])
+        prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
+        cons_thread = threading.Thread(target=consumer, args=(q, request_count))
         start = timeit.default_timer()
         prod_thread.start()
         cons_thread.start()
@@ -1675,9 +1592,7 @@ class Speedtest(object):
         stop = timeit.default_timer()
 
         self.results.bytes_sent = sum(finished)
-        self.results.upload = (
-            (self.results.bytes_sent / (stop - start)) * 8.0
-        )
+        self.results.upload = (self.results.bytes_sent / (stop - start)) * 8.0
         return self.results.upload
 
 
@@ -1685,22 +1600,24 @@ def ctrl_c(shutdown_event):
     """Catch Ctrl-C key sequence and set a SHUTDOWN_EVENT for our threaded
     operations
     """
+
     def inner(signum, frame):
         shutdown_event.set()
-        printer('\nCancelling...', error=True)
+        printer("\nCancelling...", error=True)
         sys.exit(0)
+
     return inner
 
 
 def version():
     """Print the version"""
 
-    printer('speedtest-cli %s' % __version__)
-    printer('Python %s' % sys.version.replace('\n', ''))
+    printer("speedtest-cli %s" % __version__)
+    printer("Python %s" % sys.version.replace("\n", ""))
     sys.exit(0)
 
 
-def csv_header(delimiter=','):
+def csv_header(delimiter=","):
     """Print the CSV Headers"""
     printer(SpeedtestResults.csv_header(delimiter=delimiter))
@@ -1710,11 +1627,12 @@ def csv_header(delimiter=','):
 def parse_args():
     """Function to handle building and parsing of command line arguments"""
     description = (
-        'Command line interface for testing internet bandwidth using '
-        'speedtest.net.\n'
-        '------------------------------------------------------------'
-        '--------------\n'
-        'https://github.com/sivel/speedtest-cli')
+        "Command line interface for testing internet bandwidth using "
+        "speedtest.net.\n"
+        "------------------------------------------------------------"
+        "--------------\n"
+        "https://github.com/sivel/speedtest-cli"
+    )
 
     parser = ArgParser(description=description)
     # Give optparse.OptionParser an `add_argument` method for
@@ -1723,67 +1641,101 @@ def parse_args():
         parser.add_argument = parser.add_option
     except AttributeError:
         pass
-    parser.add_argument('--no-download', dest='download', default=True,
-                        action='store_const', const=False,
-                        help='Do not perform download test')
-    parser.add_argument('--no-upload', dest='upload', default=True,
-                        action='store_const', const=False,
-                        help='Do not perform upload test')
-    parser.add_argument('--single', default=False, action='store_true',
-                        help='Only use a single connection instead of '
-                             'multiple. This simulates a typical file '
-                             'transfer.')
-    parser.add_argument('--bytes', dest='units', action='store_const',
-                        const=('byte', 8), default=('bit', 1),
-                        help='Display values in bytes instead of bits. Does '
-                             'not affect the image generated by --share, nor '
-                             'output from --json or --csv')
-    parser.add_argument('--share', action='store_true',
-                        help='Generate and provide a URL to the speedtest.net '
-                             'share results image, not displayed with --csv')
-    parser.add_argument('--simple', action='store_true', default=False,
-                        help='Suppress verbose output, only show basic '
-                             'information')
-    parser.add_argument('--csv', action='store_true', default=False,
-                        help='Suppress verbose output, only show basic '
-                             'information in CSV format. Speeds listed in '
-                             'bit/s and not affected by --bytes')
-    parser.add_argument('--csv-delimiter', default=',', type=PARSER_TYPE_STR,
-                        help='Single character delimiter to use in CSV '
-                             'output. Default ","')
-    parser.add_argument('--csv-header', action='store_true', default=False,
-                        help='Print CSV headers')
-    parser.add_argument('--json', action='store_true', default=False,
-                        help='Suppress verbose output, only show basic '
-                             'information in JSON format. Speeds listed in '
-                             'bit/s and not affected by --bytes')
-    parser.add_argument('--list', action='store_true',
-                        help='Display a list of speedtest.net servers '
-                             'sorted by distance')
-    parser.add_argument('--server', type=PARSER_TYPE_INT, action='append',
-                        help='Specify a server ID to test against. Can be '
-                             'supplied multiple times')
-    parser.add_argument('--exclude', type=PARSER_TYPE_INT, action='append',
-                        help='Exclude a server from selection. Can be '
-                             'supplied multiple times')
-    parser.add_argument('--mini', help='URL of the Speedtest Mini server')
-    parser.add_argument('--source', help='Source IP address to bind to')
-    parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT,
-                        help='HTTP timeout in seconds. Default 10')
-    parser.add_argument('--secure', action='store_true',
-                        help='Use HTTPS instead of HTTP when communicating '
-                             'with speedtest.net operated servers')
-    parser.add_argument('--no-pre-allocate', dest='pre_allocate',
-                        action='store_const', default=True, const=False,
-                        help='Do not pre allocate upload data. Pre allocation '
-                             'is enabled by default to improve upload '
-                             'performance. To support systems with '
-                             'insufficient memory, use this option to avoid a '
-                             'MemoryError')
-    parser.add_argument('--version', action='store_true',
-                        help='Show the version number and exit')
-    parser.add_argument('--debug', action='store_true',
-                        help=ARG_SUPPRESS, default=ARG_SUPPRESS)
+    parser.add_argument(
+        "--no-download",
+        dest="download",
+        default=True,
+        action="store_const",
+        const=False,
+        help="Do not perform download test",
+    )
+    parser.add_argument(
+        "--no-upload", dest="upload", default=True, action="store_const", const=False, help="Do not perform upload test"
+    )
+    parser.add_argument(
+        "--single",
+        default=False,
+        action="store_true",
+        help="Only use a single connection instead of " "multiple. This simulates a typical file " "transfer.",
+    )
+    parser.add_argument(
+        "--bytes",
+        dest="units",
+        action="store_const",
+        const=("byte", 8),
+        default=("bit", 1),
+        help="Display values in bytes instead of bits. Does "
+        "not affect the image generated by --share, nor "
+        "output from --json or --csv",
+    )
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        help="Generate and provide a URL to the speedtest.net " "share results image, not displayed with --csv",
+    )
+    parser.add_argument(
+        "--simple", action="store_true", default=False, help="Suppress verbose output, only show basic " "information"
+    )
+    parser.add_argument(
+        "--csv",
+        action="store_true",
+        default=False,
+        help="Suppress verbose output, only show basic "
+        "information in CSV format. Speeds listed in "
+        "bit/s and not affected by --bytes",
+    )
+    parser.add_argument(
+        "--csv-delimiter",
+        default=",",
+        type=PARSER_TYPE_STR,
+        help="Single character delimiter to use in CSV " 'output. Default ","',
+    )
+    parser.add_argument("--csv-header", action="store_true", default=False, help="Print CSV headers")
+    parser.add_argument(
+        "--json",
+        action="store_true",
+        default=False,
+        help="Suppress verbose output, only show basic "
+        "information in JSON format. Speeds listed in "
+        "bit/s and not affected by --bytes",
+    )
+    parser.add_argument(
+        "--list", action="store_true", help="Display a list of speedtest.net servers " "sorted by distance"
+    )
+    parser.add_argument(
+        "--server",
+        type=PARSER_TYPE_INT,
+        action="append",
+        help="Specify a server ID to test against. Can be " "supplied multiple times",
+    )
+    parser.add_argument(
+        "--exclude",
+        type=PARSER_TYPE_INT,
+        action="append",
+        help="Exclude a server from selection. Can be " "supplied multiple times",
+    )
+    parser.add_argument("--mini", help="URL of the Speedtest Mini server")
+    parser.add_argument("--source", help="Source IP address to bind to")
+    parser.add_argument("--timeout", default=10, type=PARSER_TYPE_FLOAT, help="HTTP timeout in seconds. Default 10")
+    parser.add_argument(
+        "--secure",
+        action="store_true",
+        help="Use HTTPS instead of HTTP when communicating " "with speedtest.net operated servers",
+    )
+    parser.add_argument(
+        "--no-pre-allocate",
+        dest="pre_allocate",
+        action="store_const",
+        default=True,
+        const=False,
+        help="Do not pre allocate upload data. Pre allocation "
+        "is enabled by default to improve upload "
+        "performance. To support systems with "
+        "insufficient memory, use this option to avoid a "
+        "MemoryError",
+    )
+    parser.add_argument("--version", action="store_true", help="Show the version number and exit")
+    parser.add_argument("--debug", action="store_true", help=ARG_SUPPRESS, default=ARG_SUPPRESS)
 
     options = parser.parse_args()
     if isinstance(options, tuple):
@@ -1801,14 +1753,13 @@ def validate_optional_args(args):
     with an error stating which module is missing.
     """
     optional_args = {
-        'json': ('json/simplejson python module', json),
-        'secure': ('SSL support', HTTPSConnection),
+        "json": ("json/simplejson python module", json),
+        "secure": ("SSL support", HTTPSConnection),
     }
 
     for arg, info in optional_args.items():
         if getattr(args, arg, False) and info[1] is None:
-            raise SystemExit('%s is not installed. --%s is '
-                             'unavailable' % (info[0], arg))
+            raise SystemExit("%s is not installed. --%s is " "unavailable" % (info[0], arg))
 
 
 def printer(string, quiet=False, debug=False, error=False, **kwargs):
@@ -1819,14 +1770,14 @@ def printer(string, quiet=False, debug=False, error=False, **kwargs):
     if debug:
         if sys.stdout.isatty():
-            out = '\033[1;30mDEBUG: %s\033[0m' % string
+            out = "\033[1;30mDEBUG: %s\033[0m" % string
         else:
-            out = 'DEBUG: %s' % string
+            out = "DEBUG: %s" % string
     else:
         out = string
 
     if error:
-        kwargs['file'] = sys.stderr
+        kwargs["file"] = sys.stderr
 
     if not quiet:
         print_(out, **kwargs)
@@ -1847,19 +1798,18 @@ def shell():
         version()
 
     if not args.download and not args.upload:
-        raise SpeedtestCLIError('Cannot supply both --no-download and '
-                                '--no-upload')
+        raise SpeedtestCLIError("Cannot supply both --no-download and " "--no-upload")
 
     if len(args.csv_delimiter) != 1:
-        raise SpeedtestCLIError('--csv-delimiter must be a single character')
+        raise SpeedtestCLIError("--csv-delimiter must be a single character")
 
     if args.csv_header:
         csv_header(args.csv_delimiter)
 
     validate_optional_args(args)
 
-    debug = getattr(args, 'debug', False)
-    if debug == 'SUPPRESSHELP':
+    debug = getattr(args, "debug", False)
+    if debug == "SUPPRESSHELP":
         debug = False
     if debug:
         DEBUG = True
@@ -1880,28 +1830,23 @@ def shell():
     else:
         callback = print_dots(shutdown_event)
 
-    printer('Retrieving speedtest.net configuration...', quiet)
+    printer("Retrieving speedtest.net configuration...", quiet)
     try:
-        speedtest = Speedtest(
-            source_address=args.source,
-            timeout=args.timeout,
-            secure=args.secure
-        )
+        speedtest = Speedtest(source_address=args.source, timeout=args.timeout, secure=args.secure)
     except (ConfigRetrievalError,) + HTTP_ERRORS:
-        printer('Cannot retrieve speedtest configuration', error=True)
+        printer("Cannot retrieve speedtest configuration", error=True)
         raise SpeedtestCLIError(get_exception())
 
     if args.list:
         try:
             speedtest.get_servers()
         except (ServersRetrievalError,) + HTTP_ERRORS:
-            printer('Cannot retrieve speedtest server list', error=True)
+            printer("Cannot retrieve speedtest server list", error=True)
            raise SpeedtestCLIError(get_exception())

         for _, servers in sorted(speedtest.servers.items()):
             for server in servers:
-                line = ('%(id)5s) %(sponsor)s (%(name)s, %(country)s) '
-                        '[%(d)0.2f km]' % server)
+                line = "%(id)5s) %(sponsor)s (%(name)s, %(country)s) " "[%(d)0.2f km]" % server
                 try:
                     printer(line)
                 except IOError:
@@ -1910,104 +1855,87 @@ def shell():
                         raise
         sys.exit(0)
 
-    printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'],
-            quiet)
+    printer("Testing from %(isp)s (%(ip)s)..." % speedtest.config["client"], quiet)
 
     if not args.mini:
-        printer('Retrieving speedtest.net server list...', quiet)
+        printer("Retrieving speedtest.net server list...", quiet)
         try:
             speedtest.get_servers(servers=args.server, exclude=args.exclude)
         except NoMatchedServers:
-            raise SpeedtestCLIError(
-                'No matched servers: %s' %
-                ', '.join('%s' % s for s in args.server)
-            )
+            raise SpeedtestCLIError("No matched servers: %s" % ", ".join("%s" % s for s in args.server))
        except (ServersRetrievalError,) + HTTP_ERRORS:
-            printer('Cannot retrieve speedtest server list', error=True)
+            printer("Cannot retrieve speedtest server list", error=True)
             raise SpeedtestCLIError(get_exception())
         except InvalidServerIDType:
             raise SpeedtestCLIError(
-                '%s is an invalid server type, must '
-                'be an int' % ', '.join('%s' % s for s in args.server)
+                "%s is an invalid server type, must " "be an int" % ", ".join("%s" % s for s in args.server)
             )
 
         if args.server and len(args.server) == 1:
-            printer('Retrieving information for the selected server...', quiet)
+            printer("Retrieving information for the selected server...", quiet)
         else:
-            printer('Selecting best server based on ping...', quiet)
+            printer("Selecting best server based on ping...", quiet)
         speedtest.get_best_server()
     elif args.mini:
         speedtest.get_best_server(speedtest.set_mini_server(args.mini))
 
     results = speedtest.results
 
-    printer('Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: '
-            '%(latency)s ms' % results.server, quiet)
+    printer("Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: " "%(latency)s ms" % results.server, quiet)
 
     if args.download:
-        printer('Testing download speed', quiet,
-                end=('', '\n')[bool(debug)])
-        speedtest.download(
-            callback=callback,
-            threads=(None, 1)[args.single]
-        )
-        printer('Download: %0.2f M%s/s' %
-                ((results.download / 1000.0 / 1000.0) / args.units[1],
-                 args.units[0]),
-                quiet)
+        printer("Testing download speed", quiet, end=("", "\n")[bool(debug)])
+        speedtest.download(callback=callback, threads=(None, 1)[args.single])
+        printer("Download: %0.2f M%s/s" % ((results.download / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
     else:
-        printer('Skipping download test', quiet)
+        printer("Skipping download test", quiet)
 
     if args.upload:
-        printer('Testing upload speed', quiet,
-                end=('', '\n')[bool(debug)])
-        speedtest.upload(
-            callback=callback,
-            pre_allocate=args.pre_allocate,
-            threads=(None, 1)[args.single]
-        )
-        printer('Upload: %0.2f M%s/s' %
-                ((results.upload / 1000.0 / 1000.0) / args.units[1],
-                 args.units[0]),
-                quiet)
+        printer("Testing upload speed", quiet, end=("", "\n")[bool(debug)])
+        speedtest.upload(callback=callback, pre_allocate=args.pre_allocate, threads=(None, 1)[args.single])
+        printer("Upload: %0.2f M%s/s" % ((results.upload / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
     else:
-        printer('Skipping upload test', quiet)
+        printer("Skipping upload test", quiet)
 
-    printer('Results:\n%r' % results.dict(), debug=True)
+    printer("Results:\n%r" % results.dict(), debug=True)
 
     if not args.simple and args.share:
         results.share()
 
     if args.simple:
-        printer('Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s' %
-                (results.ping,
-                 (results.download / 1000.0 / 1000.0) / args.units[1],
-                 args.units[0],
-                 (results.upload / 1000.0 / 1000.0) / args.units[1],
-                 args.units[0]))
+        printer(
+            "Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s"
+            % (
+                results.ping,
+                (results.download / 1000.0 / 1000.0) / args.units[1],
+                args.units[0],
+                (results.upload / 1000.0 / 1000.0) / args.units[1],
+                args.units[0],
) + ) elif args.csv: printer(results.csv(delimiter=args.csv_delimiter)) elif args.json: printer(results.json()) if args.share and not machine_format: - printer('Share results: %s' % results.share()) + printer("Share results: %s" % results.share()) def main(): try: shell() except KeyboardInterrupt: - printer('\nCancelling...', error=True) + printer("\nCancelling...", error=True) except (SpeedtestException, SystemExit): e = get_exception() # Ignore a successful exit, or argparse exit - if getattr(e, 'code', 1) not in (0, 2): - msg = '%s' % e + if getattr(e, "code", 1) not in (0, 2): + msg = "%s" % e if not msg: - msg = '%r' % e - raise SystemExit('ERROR: %s' % msg) + msg = "%r" % e + raise SystemExit("ERROR: %s" % msg) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a5b7e30 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[tool.black] +line-length = 120 +required-version = "22.3.0" + +[tool.isort] +profile = "black" +line_length = 120 +combine_as_imports = true +combine_star = true +known_local_folder = ["tests", "cli"] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..637434d --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +pytest==6.2.5 # see https://github.com/pytest-dev/pytest/issues/9621 +pytest-forked +pytest-asyncio==0.16.0 +black==22.3.0 +isort==5.10.1 +psutil \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..da0e072 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch==1.12.0 +accelerate==0.10.0 +huggingface-hub==0.7.0 +bitsandbytes-cuda113==0.26.0 +https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip +https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip diff --git a/src/bloom/block.py b/src/bloom/block.py index a4175a7..1898d37 100644 --- a/src/bloom/block.py +++ b/src/bloom/block.py @@ -9,8 +9,15 @@ import torch import torch.nn as nn import torch.nn.quantized.dynamic.modules.linear -from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add, - pre_process_alibi_for_pad, split_tensor_along_last_dim) +from src.bloom.ops import ( + BloomGelu, + BloomScaledSoftmax, + attention_mask_func, + build_alibi_tensor, + dropout_add, + pre_process_alibi_for_pad, + split_tensor_along_last_dim, +) class BloomAttention(nn.Module): diff --git a/src/bloom/model.py b/src/bloom/model.py index a83c378..7c522d4 100644 --- a/src/bloom/model.py +++ b/src/bloom/model.py @@ -10,14 +10,16 @@ import torch.nn.functional as F import torch.utils.checkpoint from hivemind import use_hivemind_log_handler from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm -from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings, - add_start_docstrings_to_model_forward) +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from transformers.file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.models.bloom.configuration_bloom import 
BloomConfig @@ -445,12 +447,27 @@ class LMHead(nn.Module): self.word_embeddings = word_embeddings self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu + @property + def in_features(self) -> int: + return self.word_embeddings.embedding_dim + + @property + def out_features(self) -> int: + return self.word_embeddings.num_embeddings + + @property + def weight(self): + return self.word_embeddings.weight + + @property + def bias(self): + return None + def forward(self, hidden_states): word_embeddings = self.word_embeddings.weight - + # We use 'chunked_forward' only when embeddings are in half-precision on CPU. - if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \ - word_embeddings.device.type == 'cpu': + if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu": lm_logits = self.chunked_forward(hidden_states) else: # Switch dtype in case word_embeddings are fp16/bf16 @@ -459,20 +476,20 @@ class LMHead(nn.Module): return lm_logits def chunked_forward(self, hidden_states): - """ Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. - chunk_size: provides trade-off between efficiency and extra memory consumption. + """Splits word embeddings into chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. + chunk_size: provides a trade-off between efficiency and extra memory consumption. """ assert self.chunk_size > 0, "Chunk size for chunked forward must be positive" word_embeddings = self.word_embeddings.weight num_embeddings = self.word_embeddings.num_embeddings - hidden_states = hidden_states.float() + hidden_states = hidden_states.float() output = torch.zeros(*hidden_states.shape[:-1], num_embeddings) for i in range(0, num_embeddings, self.chunk_size): - chunk = word_embeddings[i: i + self.chunk_size].float() - output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk) + chunk = word_embeddings[i : i + self.chunk_size].float() + output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk) return output @@ -565,7 +582,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None diff --git a/src/client/__init__.py b/src/client/__init__.py index 19a5ac7..8ca8c8e 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,4 +1,4 @@ from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel -from src.client.remote_sequence_info import RemoteSequenceInfo from src.client.remote_sequential import RemoteSequential +from src.client.sequence_manager import RemoteSequenceManager diff --git a/src/client/remote_model.py b/src/client/remote_model.py index badf952..aa6a991 100644 --- a/src/client/remote_model.py +++ b/src/client/remote_model.py @@ -2,15 +2,20 @@ import os from typing import Optional, Tuple +import hivemind import torch import torch.nn as nn - -import hivemind from hivemind import get_logger, use_hivemind_log_handler -from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification +from src.bloom.model import ( + BloomConfig, + BloomForCausalLM, + BloomForSequenceClassification, + BloomModel, + BloomPreTrainedModel, + LMHead, +) from src.client.remote_sequential import RemoteSequential -from src.data_structures import UID_DELIMITER use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -25,12 +30,13 @@ class DistributedBloomConfig(BloomConfig): initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models - chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU - num_prefix_tokens: int = 0 # a number of tokens for prompt tuning. + chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU + num_prefix_tokens: int = 0 # a number of tokens for prompt tuning. 
class DistributedBloomModel(BloomModel): """BloomModel, but all transformer layers are hosted by the swarm""" + config_class = DistributedBloomConfig def __init__(self, config: DistributedBloomConfig): @@ -49,7 +55,7 @@ class DistributedBloomModel(BloomModel): ) assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance" self.h = RemoteSequential(config, dht, config.dht_prefix) - + # Forbid accumulate grads for embeddings and layernorm self.set_requires_grad(False) @@ -57,6 +63,14 @@ class DistributedBloomModel(BloomModel): for p in self.parameters(): p.requires_grad = value + def forward(self, *args, use_cache=None, **kwargs): + if use_cache: + raise ValueError( + "Distributed forward does not support use_cache; for efficient cache-aware generation, " + "please use model.transformer.inference_session() or model.generate(...)" + ) + return super().forward(*args, use_cache=False, **kwargs) + class DistributedBloomPrefix(DistributedBloomModel): """DistributedBloomModel with prefix tokens for prompt tuning""" @@ -76,7 +90,7 @@ class DistributedBloomPrefix(DistributedBloomModel): return prompts def forward( - self, + self, input_ids: Optional[torch.LongTensor], inputs_embeds: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor], @@ -86,14 +100,16 @@ class DistributedBloomPrefix(DistributedBloomModel): use_cache=None, output_attentions=None, output_hidden_states=None, - return_dict=None + return_dict=None, ): - assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time" + assert ( + input_ids is None or inputs_embeds is None + ), "You cannot specify both input_ids and inputs_embeds at the same time" assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds" - + if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - + batch_size = inputs_embeds.shape[0] if attention_mask is not None: @@ -104,25 +120,26 @@ class DistributedBloomPrefix(DistributedBloomModel): inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) transformer_outputs = super().forward( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict + return_dict=return_dict, ) # Remove prefix - last_hidden_state = transformer_outputs[0][:, self.prefix_length:] - transformer_outputs['last_hidden_state'] = last_hidden_state + last_hidden_state = transformer_outputs[0][:, self.prefix_length :] + transformer_outputs["last_hidden_state"] = last_hidden_state return transformer_outputs class DistributedBloomForCausalLM(BloomForCausalLM): - """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" + """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm""" + config_class = DistributedBloomConfig def __init__(self, config: DistributedBloomConfig): @@ -136,11 +153,23 @@ class DistributedBloomForCausalLM(BloomForCausalLM): # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.lm_head.word_embeddings + def get_input_embeddings(self): + return self.transformer.word_embeddings - def set_output_embeddings(self, new_embeddings): - self.lm_head.word_embeddings.weight = 
new_embeddings.weight + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return None + return self.lm_head + + def set_input_embeddings(self, new_embeddings: nn.Embedding): + assert isinstance(new_embeddings, nn.Embedding) + self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings + assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings + + def set_output_embeddings(self, new_lm_head: nn.Linear): + assert new_lm_head.bias is None, "LMHead ties its weight to the input embeddings and has no output bias" + with torch.no_grad(): + self.lm_head.word_embeddings.weight[...] = new_lm_head.weight class DistributedBloomForSequenceClassification(BloomForSequenceClassification): diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 1b3a71e..7026e7c 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextlib import logging import random +from typing import Optional, Union import torch from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler @@ -12,7 +13,7 @@ from torch import nn import src from src.client.remote_block import RemoteTransformerBlock -from src.client.remote_sequence_info import RemoteSequenceInfo +from src.client.sequence_manager import RemoteSequenceManager from src.data_structures import UID_DELIMITER from src.dht_utils import _create_remote_modules_from_infos @@ -25,7 +26,15 @@ class RemoteSequential(nn.Module): A sequence of transformer blocks hosted by the swarm. """ - def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3): + def __init__( + self, + config: src.DistributedBloomConfig, + dht: DHT, + prefix: str, + max_retries: int = 3, + p2p: Optional[P2P] = None, + sequence_manager: Optional[RemoteSequenceManager] = None, + ): logger.warning(f"{self.__class__.__name__} is in active development; expect adventures") if prefix.endswith(UID_DELIMITER): logger.warning( @@ -39,12 +48,18 @@ self.dht = dht self.prefix = prefix self.max_retries = max_retries - self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) - - block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)) - - logger.debug(f"Remote block uids: {block_uids}") - self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids) + self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p + + block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + if sequence_manager is None: + logger.debug(f"Creating new sequence manager for block uids: {block_uids}") + self.sequence_manager = RemoteSequenceManager(dht, block_uids) + self.is_subsequence = False + else: + assert isinstance(sequence_manager.block_uids, list) + logger.debug(f"Reusing sequence manager with {len(sequence_manager)} blocks") + self.sequence_manager = sequence_manager + self.is_subsequence = self.sequence_manager.block_uids != block_uids def forward(self, inputs: torch.Tensor): assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed @@ -64,27 +78,38 @@ logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True) return inputs - def __getitem__(self, block_index: int): - assert 0 <= block_index < self.config.n_layer - (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p) - return module + def __getitem__(self, ix: Union[int, slice]) -> 
Union[RemoteTransformerBlock, RemoteSequential]: + assert isinstance(ix, (int, slice)) + if isinstance(ix, int): + assert 0 <= ix < self.config.n_layer + (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p) + return module + else: + return RemoteSequential( + self.config, + self.dht, + prefix=self.prefix, + max_retries=self.max_retries, + p2p=self.p2p, + sequence_manager=self.sequence_manager[ix], + ) def __iter__(self): for block_index in range(self.config.n_layer): yield self[block_index] def __len__(self): - return len(self.remote_sequence_info) + return len(self.sequence_manager) def inference_session(self) -> RemoteSequentialInferenceSession: - self.remote_sequence_info.update_() - return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p) + self.sequence_manager.update_() + return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p) class RemoteSequentialInferenceSession: """An interface to a multi-step *inference* session for a sequence of remote transformer blocks""" - def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P): + def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P): self.remote_sequence_info = remote_sequence_info self.p2p = p2p self.closed = False diff --git a/src/client/remote_sequence_info.py b/src/client/sequence_manager.py similarity index 68% rename from src/client/remote_sequence_info.py rename to src/client/sequence_manager.py index 922d6c9..7a05bb2 100644 --- a/src/client/remote_sequence_info.py +++ b/src/client/sequence_manager.py @@ -1,29 +1,27 @@ from __future__ import annotations import threading -from typing import List, NamedTuple, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Union -from hivemind import DHT, PeerID +from hivemind import DHT, DHTExpiration from hivemind.utils.logging import get_logger, use_hivemind_log_handler -from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState +from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState from src.dht_utils import get_remote_module_infos use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)]) - - -class RemoteSequenceInfo: +class RemoteSequenceManager: """Keeps and updates the meta-information about which peers host which blocks""" dht: DHT block_uids: List[ModuleUID] block_infos: List[Optional[RemoteModuleInfo]] - spans_by_priority: List[Span] # sorted from best to worst - spans_containing_block: Tuple[List[Span]] + spans_by_priority: List[RemoteSpanInfo] # sorted from best to worst + spans_containing_block: Tuple[List[RemoteSpanInfo], ...] 
+ last_update_time: DHTExpiration lock_changes: threading.Lock def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]): @@ -32,6 +30,7 @@ class RemoteSequenceInfo: self.block_infos = [None] * len(self.block_uids) self.spans_by_priority = [] self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids))) + self.last_update_time = -float("inf") self.lock_changes = threading.Lock() self.update_() @@ -39,6 +38,18 @@ class RemoteSequenceInfo: assert info is not None, f"Found no remote peers for block {uid}" assert self.spans_by_priority and self.spans_containing_block + def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: + """Get a RemoteSequenceManager for a sub-sequence of blocks""" + assert isinstance(ix, (int, slice)) + if not isinstance(ix, slice): + ix = slice(int(ix), int(ix) + 1, 1) + with self.lock_changes: + subseq = RemoteSequenceManager(self.dht, self.block_uids[ix]) + subseq.block_infos = self.block_infos[ix] + subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos) + subseq.last_update_time = self.last_update_time + return subseq + def update_(self): with self.lock_changes: self.update_block_infos_() @@ -67,15 +78,15 @@ class RemoteSequenceInfo: if server.state != ServerState.ONLINE: continue if peer_id not in active_spans: - active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id) + active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) else: # peer_id in active_spans - active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1) + active_spans[peer_id].end = block_index + 1 for peer_id in list(active_spans.keys()): if ( - peer_id not in info.servers or - info.servers[peer_id].state != ServerState.ONLINE or - block_index == len(block_infos) - 1 + peer_id not in info.servers + or info.servers[peer_id].state != ServerState.ONLINE + or block_index == len(block_infos) - 1 ): closed_spans.append(active_spans.pop(peer_id)) assert not active_spans diff --git a/src/data_structures.py b/src/data_structures.py index a12bb83..d0719a9 100644 --- a/src/data_structures.py +++ b/src/data_structures.py @@ -23,5 +23,16 @@ class ServerInfo: @dataclass class RemoteModuleInfo: + """A remote module that is served by one or more servers""" + uid: ModuleUID servers: Dict[PeerID, ServerInfo] + + +@dataclass +class RemoteSpanInfo: + """A chain of remote blocks served by one specific remote peer""" + + start: int + end: int + peer_id: PeerID diff --git a/src/dht_utils.py b/src/dht_utils.py index 46c6196..fe5df32 100644 --- a/src/dht_utils.py +++ b/src/dht_utils.py @@ -136,8 +136,12 @@ async def _get_remote_module_infos( try: peer_id = PeerID.from_base58(peer_id) state, throughput = server_info.value - if not (isinstance(state, int) and isinstance(throughput, float) and - math.isfinite(throughput) and throughput >= 0.0): + if not ( + isinstance(state, int) + and isinstance(throughput, float) + and math.isfinite(throughput) + and throughput >= 0.0 + ): raise ValueError(f"Invalid server info: {server_info}") servers[peer_id] = ServerInfo(ServerState(state), throughput) except (TypeError, ValueError) as e: diff --git a/src/server/block_selection.py b/src/server/block_selection.py index 4ce63fd..75ee471 100644 --- a/src/server/block_selection.py +++ b/src/server/block_selection.py @@ -9,10 +9,10 @@ def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[Remot if module is None: throughputs.append(0) continue - 
throughputs.append(sum(server.throughput for server in module.servers.values() - if server.state != ServerState.OFFLINE)) + throughputs.append( + sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE) + ) - options = [(sorted(throughputs[i:i + num_blocks]), i) - for i in range(0, len(throughputs) - num_blocks + 1)] + options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)] best_start = min(options)[1] return list(range(best_start, best_start + num_blocks)) diff --git a/src/server/server.py b/src/server/server.py index 790ac4c..a379648 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -4,7 +4,7 @@ import multiprocessing as mp import random import threading import time -from typing import Dict, Literal, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time @@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger, use_hivemind_log_handler -from src import declare_active_modules, BloomConfig +from src import BloomConfig, declare_active_modules from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from src.dht_utils import get_remote_module_infos @@ -98,7 +98,7 @@ class Server(threading.Thread): cls, prefix: Optional[str], converted_model_name_or_path: str, - throughput: Union[float, Literal['auto', 'eval']], + throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, num_handlers: Optional[int] = None, @@ -140,17 +140,15 @@ class Server(threading.Thread): device = device or ("cuda" if torch.cuda.is_available() else "cpu") memory_cache = MemoryCache(device, cache_size_bytes) - assert isinstance(throughput, float) or throughput in ['auto', 'eval'] - if throughput in ['auto', 'eval']: - throughput = get_host_throughput(device, force_eval=(throughput == 'eval')) + assert isinstance(throughput, float) or throughput in ["auto", "eval"] + if throughput in ["auto", "eval"]: + throughput = get_host_throughput(device, force_eval=(throughput == "eval")) if isinstance(torch_dtype, str): torch_dtype = DTYPE_MAP[torch_dtype] assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - block_config = BloomConfig.from_pretrained( - converted_model_name_or_path, use_auth_token=use_auth_token - ) + block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) if block_indices is not None: try: @@ -288,7 +286,7 @@ class ModuleAnnouncerThread(threading.Thread): throughput: float, update_period: float = 30, expiration: float, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.module_backends = module_backends diff --git a/src/server/throughput.py b/src/server/throughput.py index 981c6cc..f14e936 100644 --- a/src/server/throughput.py +++ b/src/server/throughput.py @@ -20,10 +20,10 @@ use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -DEFAULT_CACHE_PATH = Path(Path.home(), '.cache', project_name, 'throughput.json') -DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, 'throughput.lock') +DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json") +DEFAULT_LOCK_PATH = 
Path(tempfile.gettempdir(), project_name, "throughput.lock") -SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], 'cli', 'speed_test.py') +SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py") @dataclass @@ -43,7 +43,7 @@ def get_host_throughput( # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) - with open(lock_path, 'wb') as lock_fd: + with open(lock_path, "wb") as lock_fd: logger.info("Loading throughput info") fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed @@ -63,7 +63,7 @@ def get_host_throughput( info = measure_throughput_info() try: os.makedirs(cache_path.parent, exist_ok=True) - with open(cache_path, 'w') as cache_fd: + with open(cache_path, "w") as cache_fd: json.dump(asdict(info), cache_fd) except Exception: logger.exception(f"Failed to save throughput info in {cache_path}") @@ -73,29 +73,30 @@ def get_host_throughput( def measure_throughput_info() -> ThroughputInfo: - logger.info("Measuring network, CPU, and GPU throughput. " - "This takes about a minute and will be cached for future runs") + logger.info( + "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs" + ) # We measure throughput in "(inference) requests per second" (RPS) using a fixed model - config = BloomConfig.from_pretrained('bigscience/test-bloomd-6b3') + config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3") network_rps = measure_network_rps(config) - device_rps = {'cpu': measure_device_rps('cpu', config)} + device_rps = {"cpu": measure_device_rps("cpu", config)} if torch.cuda.is_available(): - device_rps['cuda'] = measure_device_rps('cuda', config) + device_rps["cuda"] = measure_device_rps("cuda", config) return ThroughputInfo(network_rps=network_rps, device_rps=device_rps) def measure_network_rps(config: BloomConfig) -> float: - proc = subprocess.run([SPEED_TEST_PATH, '--json'], capture_output=True) + proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True) if proc.returncode != 0: raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})") network_info = json.loads(proc.stdout) bits_per_request = config.hidden_size * 32 - network_rps = min(network_info['download'], network_info['upload']) / bits_per_request + network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request logger.info( f"Network throughput: " @@ -120,7 +121,7 @@ def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n elapsed += time.perf_counter() - start_time device_rps = n_steps / elapsed - device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == 'cuda' else 'CPU' + device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU" logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS") return device_rps diff --git a/tests/test.id b/tests/test.id new file mode 100644 index 0000000..2806712 Binary files /dev/null and b/tests/test.id differ diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 2894fad..1a0caa6 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -3,6 +3,7 @@ import os import hivemind import torch +import transformers from src.bloom.from_pretrained import load_pretrained_block from src.client.remote_block import RemoteTransformerBlock @@ -19,16 
+20,18 @@ if not BLOCK_UID: raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested") REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") -REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1])) +REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1])) def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + remote_block = get_remote_module(dht, BLOCK_UID) assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT" assert isinstance(remote_block, RemoteTransformerBlock) + ref_config = transformers.AutoConfig.from_pretrained(REF_NAME) - inputs = torch.randn(1, 8, 4096) + inputs = torch.randn(1, 8, ref_config.hidden_size) (outputs_forward,) = remote_block(inputs) outputs_inference = [] diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py new file mode 100644 index 0000000..140d58d --- /dev/null +++ b/tests/test_chained_calls.py @@ -0,0 +1,97 @@ +###### +# Warning: this test is a work in progress. It will be modified soon. +# - if you want more stable tests, see test_block_exact_match +# - if you want to figure out chained inference, ask yozh + +import os + +import hivemind +import torch +import transformers +from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo + +from src.bloom.from_pretrained import load_pretrained_block +from src.client.remote_block import RemoteTransformerBlock +from src.dht_utils import get_remote_module + +INITIAL_PEERS = os.environ.get("INITIAL_PEERS") +if not INITIAL_PEERS: + raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") +INITIAL_PEERS = INITIAL_PEERS.split() + + +MODEL_NAME = os.environ.get("MODEL_NAME") +if not MODEL_NAME: + raise RuntimeError("Must specify MODEL_NAME as the name of a model to be tested") + +REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") + + +def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): + dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + config = transformers.AutoConfig.from_pretrained(MODEL_NAME) + remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") + assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" + assert isinstance(remote_block, RemoteTransformerBlock) + + _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info + remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id) + + ref_blocks = [ + load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), + load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), + load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32), + ] + inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True) + outputs_rpc = remote_block.forward(inputs)[0] + outputs_rpc.sum().backward() + grads_rpc = inputs.grad + + inputs.grad = None + hidden_states = inputs + for ref_block in ref_blocks: + hidden_states = ref_block.forward(hidden_states)[0] + outputs_ref = hidden_states + outputs_ref.sum().backward() + grads_ref = inputs.grad + + assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward) + assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward) + + +def test_chained_inference_exact_match(atol_inference=1e-4): + dht = 
hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + config = transformers.AutoConfig.from_pretrained(MODEL_NAME) + remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") + assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" + assert isinstance(remote_block, RemoteTransformerBlock) + + _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info + remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id) + + inputs = torch.randn(1, 8, config.hidden_size) + + outputs_inference = [] + with remote_block.inference_session() as sess: + for i in range(inputs.shape[1]): + outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + outputs_inference = torch.cat(outputs_inference, dim=1) + + ref_blocks = [ + load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), + load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), + ] + outputs_ref = [] + caches = [None, None] + for i in range(inputs.shape[1]): + new_caches = [] + hidden_states = inputs[:, i : i + 1, :] + for ref_block, cache in zip(ref_blocks, caches): + with torch.no_grad(): + hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache) + new_caches.append(new_cache) + + outputs_ref.append(hidden_states) + caches = new_caches + outputs_ref = torch.cat(outputs_ref, dim=1) + assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference) diff --git a/tests/test_chained_forward_backward.py b/tests/test_chained_forward_backward.py deleted file mode 100644 index 0a1b726..0000000 --- a/tests/test_chained_forward_backward.py +++ /dev/null @@ -1,59 +0,0 @@ -###### -# Warning:torch this test is a work in progress. It will be modified soon. 
-# - if you want more stable tests, see test_block_exact_match -# - if you want to figure out chained inference, ask yozh - -import os - -import hivemind -import torch -from hivemind.moe.expert_uid import ExpertInfo - -from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock -from src.dht_utils import get_remote_module - -INITIAL_PEERS = os.environ.get("INITIAL_PEERS") -if not INITIAL_PEERS: - raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") -INITIAL_PEERS = INITIAL_PEERS.split() - - -BLOCK_UID = os.environ.get("BLOCK_UID") -if not BLOCK_UID: - raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested") - -REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") - - -# seq_length > 128: rpc_forward_stream & rpc_backward_stream -# seq_length <= 128: rpc_forward & rpc_backward -def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - (remote_block,) = get_remote_module(dht, BLOCK_UID) - assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id) - - ref_blocks = [ - load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32), - ] - inputs = torch.randn(1, seq_length, 4096, requires_grad=True) - outputs_rpc = remote_block.forward(inputs)[0] - outputs_rpc.sum().backward() - grads_rpc = inputs.grad - - inputs.grad = None - hidden_states = inputs - for ref_block in ref_blocks: - hidden_states = ref_block.forward(hidden_states)[0] - outputs_ref = hidden_states - outputs_ref.sum().backward() - grads_ref = inputs.grad - - assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward) - assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward) diff --git a/tests/test_chained_inference.py b/tests/test_chained_inference.py deleted file mode 100644 index 4d10079..0000000 --- a/tests/test_chained_inference.py +++ /dev/null @@ -1,64 +0,0 @@ -###### -# Warning:torch this test is a work in progress. It will be modified soon. 
-# - if you want more stable tests, see test_block_exact_match -# - if you want to figure out chained inference, ask yozh - -import os - -import hivemind -import torch -from hivemind.moe.expert_uid import ExpertInfo - -from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock -from src.dht_utils import get_remote_module - -INITIAL_PEERS = os.environ.get("INITIAL_PEERS") -if not INITIAL_PEERS: - raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") -INITIAL_PEERS = INITIAL_PEERS.split() - - -BLOCK_UID = os.environ.get("BLOCK_UID") -if not BLOCK_UID: - raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested") - -REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") -REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1])) - - -def test_remote_block_exact_match(atol_inference=1e-4): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - remote_block = get_remote_module(dht, BLOCK_UID) - assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id) - - inputs = torch.randn(1, 8, 4096) - - outputs_inference = [] - with remote_block.inference_session() as sess: - for i in range(inputs.shape[1]): - outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) - outputs_inference = torch.cat(outputs_inference, dim=1) - - ref_blocks = [ - load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), - ] - outputs_ref = [] - caches = [None, None] - for i in range(inputs.shape[1]): - new_caches = [] - hidden_states = inputs[:, i : i + 1, :] - for ref_block, cache in zip(ref_blocks, caches): - with torch.no_grad(): - hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache) - new_caches.append(new_cache) - - outputs_ref.append(hidden_states) - caches = new_caches - outputs_ref = torch.cat(outputs_ref, dim=1) - assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 657b965..5a60365 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -24,9 +24,10 @@ if not MODEL_NAME: REF_NAME = os.environ.get("REF_NAME") -def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"): +def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + assert isinstance(model, DistributedBloomForCausalLM) assert len(model.transformer.h) == model.config.n_layer test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] @@ -35,26 +36,29 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix=" logger.info("Forward outputs are finite") if REF_NAME: - ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME) - dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool) - # note: this creates a dummy mask to make the test compatible with older transformer versions - # prior to 
https://github.com/huggingface/transformers/pull/17837 - ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits - assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward) + with torch.no_grad(): + ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME) + dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool) + # note: this creates a dummy mask to make the test compatible with older transformer versions + # prior to https://github.com/huggingface/transformers/pull/17837 + ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits + assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward) + del ref_model, ref_outputs else: logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set") - embs = model.transformer.word_embeddings(test_inputs) - embs = model.transformer.word_embeddings_layernorm(embs) - recurrent_outputs = [] - with model.transformer.h.inference_session() as sess: - for t in range(embs.shape[1]): - recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) - recurrent_outputs = torch.cat(recurrent_outputs, dim=1) - recurrent_outputs = model.transformer.ln_f(recurrent_outputs) - - dictionary = model.transformer.word_embeddings.weight.t() - recurrent_outputs = recurrent_outputs.to(dictionary.dtype) - recurrent_outputs = (recurrent_outputs @ dictionary).float() + with torch.inference_mode(): + embs = model.transformer.word_embeddings(test_inputs) + embs = model.transformer.word_embeddings_layernorm(embs) + recurrent_outputs = [] + with model.transformer.h.inference_session() as sess: + for t in range(embs.shape[1]): + recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) + recurrent_outputs = torch.cat(recurrent_outputs, dim=1) + recurrent_outputs = model.transformer.ln_f(recurrent_outputs) + + dictionary = model.transformer.word_embeddings.weight.t() + recurrent_outputs = recurrent_outputs.to(dictionary.dtype) + recurrent_outputs = (recurrent_outputs @ dictionary).float() assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference) logger.info("Inference is consistent with forward")
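Usage sketch for the slicing API introduced in src/client/remote_sequential.py and src/client/sequence_manager.py. This is an illustrative sketch rather than code from this changeset: it assumes a running swarm reachable through initial peers and a converted model whose DHT prefix equals its repo name (as the tests above assume); INITIAL_PEERS and MODEL_NAME are placeholders to be supplied by the caller.

    import os

    import hivemind
    import torch

    from src.client.remote_model import DistributedBloomConfig
    from src.client.remote_sequential import RemoteSequential

    # Hypothetical placeholders: point these at your own swarm and converted model.
    INITIAL_PEERS = os.environ["INITIAL_PEERS"].split()
    MODEL_NAME = os.environ["MODEL_NAME"]

    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)

    # One RemoteSequenceManager is created here and shared with every slice below.
    blocks = RemoteSequential(config, dht, prefix=MODEL_NAME)

    first_half = blocks[: config.n_layer // 2]   # a RemoteSequential over a sub-range of blocks
    second_half = blocks[config.n_layer // 2 :]  # reuses the parent's p2p handle and a sliced manager
    single_block = blocks[0]                     # an individual RemoteTransformerBlock

    # Each slice is intended to run only its own blocks on remote servers,
    # so chaining the two halves covers the whole model.
    hidden_states = torch.randn(1, 8, config.hidden_size)
    hidden_states = first_half(hidden_states)
    hidden_states = second_half(hidden_states)

The design choice worth noting: slicing goes through RemoteSequenceManager.__getitem__, which copies the already-fetched block_infos and recomputes spans under the parent's lock instead of re-querying the DHT, so taking sub-sequences is cheap and only an explicit update_() refreshes the server information.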