Fix psutil-related AccessDenied crash, disable --load_in_8bit by default in case of TP (#188)

* Don't count open fds since it leads to AccessDenied crashes on some machines
* Use --load_in_8bit=False by default in case of tensor parallelism
* Install petals from PyPI in fine-tuning tutorials
pull/189/head
Alexander Borzunov 1 year ago committed by GitHub
parent 93bed7da5a
commit a617ce3cfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,8 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q git+https://github.com/bigscience-workshop/petals\n",
"!pip install -q datasets wandb"
"%pip install -q petals datasets wandb"
]
},
{

@ -36,8 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q git+https://github.com/bigscience-workshop/petals\n",
"!pip install -q datasets wandb"
"%pip install -q petals datasets wandb"
]
},
{

@ -9,7 +9,6 @@ import time
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import psutil
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import check_reachability
from petals.server.throughput import get_host_throughput
from petals.server.throughput import get_dtype_name, get_host_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@ -146,12 +145,6 @@ class Server:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
if load_in_8bit is None:
load_in_8bit = device.type == "cuda"
if load_in_8bit:
logger.info("Model weights will be loaded in 8-bit format")
self.load_in_8bit = load_in_8bit
if tensor_parallel_devices is None:
tensor_parallel_devices = (device,)
self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
@ -159,6 +152,17 @@ class Server:
logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
check_device_balance(self.tensor_parallel_devices)
if load_in_8bit is None:
load_in_8bit = device.type == "cuda"
if load_in_8bit and len(self.tensor_parallel_devices) > 1:
load_in_8bit = False
logger.warning(
"Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
"You can explicitly set `--load_in_8bit True` to override this"
)
self.load_in_8bit = load_in_8bit
logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
@ -167,8 +171,7 @@ class Server:
first_block_index, last_block_index = block_indices.split(":")
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
except Exception as e:
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
raise
raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
block_indices = range(first_block_index, last_block_index)
num_blocks = len(block_indices)
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
@ -301,10 +304,6 @@ class Server:
del self.module_container
gc.collect() # In particular, this closes unused file descriptors
cur_proc = psutil.Process()
num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
if self.device.type == "cuda":
torch.cuda.empty_cache()

Loading…
Cancel
Save