diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index c52ad4b..bb89016 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -76,6 +76,10 @@ def load_pretrained_block( if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = param.to(torch_dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) + else: + cur_block = getattr(block, param_name) + dummy_value = torch.empty_like(cur_block, device="cpu") + set_module_tensor_to_device(block, param_name, "cpu", dummy_value) logger.info(f"Loaded {model_name} block {block_index}") logger.debug(f"Details: {report}")