From 2116df08bcbdaff48d185af3554f15582d615d45 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Wed, 29 Mar 2023 04:21:37 +0400
Subject: [PATCH] Fix deps, enable 8-bit by default for TP (#298)

This PR fixes the issues reported in #290:

- The hivemind bfloat16 codec crashed on dummy tensors (tensors with 0
  elements), see https://github.com/learning-at-home/hivemind/pull/560.
  As a temporary measure, this PR makes Petals depend on the latest
  hivemind version from the repo.
- The transformers version check in `src/petals/bloom/block.py` did not
  match the version range allowed in `setup.cfg`.

Also:

- This PR enables 8-bit by default for tensor parallelism (TP). Even
  though TP in 8-bit may be slower, we currently prefer to host more
  blocks to increase the network's stability.
---
 setup.cfg                   | 2 +-
 src/petals/bloom/block.py   | 4 ++--
 src/petals/server/server.py | 6 ------
 3 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index ba3bedc..c485cd5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3
-    hivemind==1.1.6
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py
index 78171cf..9037ee4 100644
--- a/src/petals/bloom/block.py
+++ b/src/petals/bloom/block.py
@@ -13,8 +13,8 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0"
+        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"


 class WrappedBloomBlock(BloomBlock):
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 29e9d6b..4563e28 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -163,12 +163,6 @@ class Server:

         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
-        if load_in_8bit and len(self.tensor_parallel_devices) > 1:
-            load_in_8bit = False
-            logger.warning(
-                "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
-                "You can explicitly set `--load_in_8bit True` to override this"
-            )
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")

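Note on the first fix: below is a minimal repro sketch of the codec crash, assuming
hivemind's public helpers `serialize_torch_tensor` / `deserialize_torch_tensor` from
`hivemind.compression`. Per the description above, this round-trip crashed on
hivemind 1.1.6, while the pinned repo version (with learning-at-home/hivemind#560)
handles it:

    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor

    # A "dummy" tensor: zero elements, bfloat16 dtype
    dummy = torch.zeros(0, dtype=torch.bfloat16)

    # On hivemind 1.1.6, the bfloat16 codec crashed on this 0-element tensor;
    # with the fix pinned above, it should round-trip cleanly.
    restored = deserialize_torch_tensor(serialize_torch_tensor(dummy))
    assert restored.numel() == 0 and restored.dtype == torch.bfloat16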