Fix psutil-related AccessDenied crash, disable --load_in_8bit by default in case of TP (#188)

* Don't count open fds since it leads to AccessDenied crashes on some machines
* Use --load_in_8bit=False by default in case of tensor parallelism
* Install petals from PyPI in fine-tuning tutorials
Alexander Borzunov committed via GitHub
parent 93bed7da5a
commit a617ce3cfa

@@ -36,8 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install -q datasets wandb"
+    "%pip install -q petals datasets wandb"
    ]
   },
   {

@@ -36,8 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install -q datasets wandb"
+    "%pip install -q petals datasets wandb"
    ]
   },
   {
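Both hunks above make the same change in the two fine-tuning tutorial notebooks. Aside from installing petals from PyPI instead of Git, the cells switch from `!pip` to `%pip`: `!pip` shells out to whichever pip is first on PATH, which in a hosted notebook may belong to a different interpreter than the running kernel, while the `%pip` magic installs into the kernel's own environment. A rough Python equivalent of what the new cell guarantees (a sketch, not how IPython implements the magic):

import subprocess
import sys

# Install into the interpreter that actually runs the notebook kernel,
# which is what `%pip` ensures and `!pip` does not
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "petals", "datasets", "wandb"])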

@@ -9,7 +9,6 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
-import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import check_reachability
-from petals.server.throughput import get_host_throughput
+from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
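The newly imported get_dtype_name is used further down to log the effective weight format. For illustration, a hypothetical re-implementation of such a helper (the real one lives in petals.server.throughput and may differ in detail):

import torch

def get_dtype_name_sketch(torch_dtype: torch.dtype, load_in_8bit: bool) -> str:
    # 8-bit quantization overrides the nominal torch dtype in the log message
    return "int8" if load_in_8bit else str(torch_dtype).replace("torch.", "")

assert get_dtype_name_sketch(torch.float16, load_in_8bit=False) == "float16"
assert get_dtype_name_sketch(torch.float16, load_in_8bit=True) == "int8"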
@@ -146,12 +145,6 @@ class Server:
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
         self.torch_dtype = torch_dtype
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        if load_in_8bit:
-            logger.info("Model weights will be loaded in 8-bit format")
-        self.load_in_8bit = load_in_8bit
-
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
         self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
@@ -159,6 +152,17 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
+        if load_in_8bit is None:
+            load_in_8bit = device.type == "cuda"
+            if load_in_8bit and len(self.tensor_parallel_devices) > 1:
+                load_in_8bit = False
+                logger.warning(
+                    "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
+                    "You can explicitly set `--load_in_8bit True` to override this"
+                )
+        self.load_in_8bit = load_in_8bit
+        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
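The effect of moving this block: the 8-bit default is now resolved only after self.tensor_parallel_devices is known, and the fallback applies only when the flag was left unset, so an explicit --load_in_8bit True still wins. A standalone sketch of the decision (hypothetical helper name):

from typing import Optional, Sequence

import torch

def resolve_load_in_8bit(
    load_in_8bit: Optional[bool], device: torch.device, tensor_parallel_devices: Sequence[torch.device]
) -> bool:
    if load_in_8bit is None:
        # Default to 8-bit on GPU, except when weights are split across
        # multiple tensor-parallel devices (TP + 8-bit is not supported yet)
        load_in_8bit = device.type == "cuda" and len(tensor_parallel_devices) <= 1
    return load_in_8bit

tp_devices = (torch.device("cuda:0"), torch.device("cuda:1"))
assert resolve_load_in_8bit(None, torch.device("cuda"), (torch.device("cuda"),)) is True
assert resolve_load_in_8bit(None, torch.device("cuda"), tp_devices) is False
assert resolve_load_in_8bit(True, torch.device("cuda"), tp_devices) is True  # explicit flag wins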
@@ -167,8 +171,7 @@ class Server:
                 first_block_index, last_block_index = block_indices.split(":")
                 first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
             except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
+                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
             block_indices = range(first_block_index, last_block_index)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
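Besides including the offending value in the message, raising ValueError directly (instead of logger.error plus a bare raise) still preserves the original error: a raise inside an except block chains the caught exception as __context__, so the traceback shows both. A standalone sketch of the parsing (hypothetical function name):

def parse_block_indices(block_indices: str) -> range:
    try:
        first, last = block_indices.split(":")
        first, last = int(first.strip()), int(last.strip())
    except Exception:
        # Raised inside `except`, so the original exception is chained
        # automatically and remains visible in the traceback
        raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
    return range(first, last)

assert parse_block_indices(" 0 : 18 ") == range(0, 18)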
@@ -301,10 +304,6 @@ class Server:
             del self.module_container
             gc.collect()  # In particular, this closes unused file descriptors
 
-            cur_proc = psutil.Process()
-            num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
-            logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
-
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
 
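On the removed counting: on Linux, psutil's num_fds() reads the process's /proc/<pid>/fd directory, and doing that for child processes can raise psutil.AccessDenied on hardened machines, which crashed cleanup here. The commit drops the log line entirely; a crash-safe variant, had the count been worth keeping, might look like this (a sketch under that assumption):

import logging

import psutil

logger = logging.getLogger(__name__)

def log_open_fds() -> None:
    # Hypothetical defensive version of the removed code
    try:
        cur_proc = psutil.Process()
        procs = [cur_proc] + cur_proc.children(recursive=True)
        num_fds = sum(proc.num_fds() for proc in procs)
        logger.info(f"Cleaning up, left {num_fds} open file descriptors")
    except (psutil.AccessDenied, psutil.NoSuchProcess):
        # Some machines restrict /proc access (and children may exit mid-scan);
        # skip the diagnostic instead of crashing
        pass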
