Fix deps, enable 8-bit by default for TP (#298)

This PR fixes two issues from #290:

- The hivemind bfloat16 codec crashed on dummy tensors (tensors with 0 elements), see https://github.com/learning-at-home/hivemind/pull/560. As a temporary measure, this PR makes Petals depend on the latest hivemind version from the repo (a minimal repro sketch follows this list).
- The transformers version check did not match the versions allowed in `setup.cfg` (see the worked example after the corresponding diff below).
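For context, here is a minimal repro sketch of the first issue. This is not code from this PR, and the exact compression type that triggers the bfloat16 codec is an assumption; `serialize_torch_tensor` / `deserialize_torch_tensor` are hivemind's public serialization helpers:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

# A "dummy" tensor with zero elements, like the placeholders Petals passes
# when an RPC argument carries no data.
dummy = torch.empty(0, dtype=torch.bfloat16)

# Round-trip through hivemind's tensor serialization. Before
# https://github.com/learning-at-home/hivemind/pull/560, the bfloat16
# codec crashed on such zero-element tensors; with the fix, the
# round-trip succeeds.
serialized = serialize_torch_tensor(dummy, CompressionType.NONE)
restored = deserialize_torch_tensor(serialized)
assert restored.numel() == 0 and restored.dtype == torch.bfloat16
```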

Also:

- This PR enables 8-bit by default for tensor parallelism (TP). Even though TP in 8-bit may be slower, we currently prefer hosting more blocks to increase the network's stability (a sketch of the resulting default follows).
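To illustrate, a sketch of the resulting default (the helper below is hypothetical, not the server's actual code; see the server diff for the real change):

```python
from typing import Sequence

import torch

def default_load_in_8bit(device: torch.device, tensor_parallel_devices: Sequence[torch.device]) -> bool:
    # Hypothetical helper mirroring the server logic after this PR: the
    # multi-GPU override is gone, so only the device type matters.
    return device.type == "cuda"

# 8-bit now stays enabled even when TP spans several GPUs; a peer can
# still opt out explicitly, e.g. with `--load_in_8bit False`.
assert default_load_in_8bit(torch.device("cuda"), [torch.device("cuda:0"), torch.device("cuda:1")])
```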
Alexander Borzunov committed via GitHub
commit 2116df08bc (parent 987f4d2b2f)

@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3
-    hivemind==1.1.6
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

@@ -13,8 +13,8 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0"
+        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"


 class WrappedBloomBlock(BloomBlock):
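To make the version-check mismatch concrete, a short worked example with `packaging.version` (the same helper the assert above uses). The old strict lower bound rejected exactly the minimum version that `setup.cfg` permits:

```python
from packaging import version

# Old check: required transformers strictly newer than 4.26.0, so
# transformers 4.25.1 (allowed by setup.cfg) failed the assert.
assert not (version.parse("4.26.0") < version.parse("4.25.1"))

# New check: inclusive lower bound matching `transformers>=4.25.1,<5.0.0`.
assert version.parse("4.25.1") <= version.parse("4.25.1") < version.parse("5.0.0")
```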

@@ -163,12 +163,6 @@ class Server:
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
-        if load_in_8bit and len(self.tensor_parallel_devices) > 1:
-            load_in_8bit = False
-            logger.warning(
-                "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
-                "You can explicitly set `--load_in_8bit True` to override this"
-            )
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
