diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py
new file mode 100755
index 0000000..0a7d4f8
--- /dev/null
+++ b/benchmarks/benchmark_forward.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import torch
+from hivemind.utils.logging import get_logger
+
+from petals import AutoDistributedModel
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--n_steps", type=int, default=100)
+    parser.add_argument("--batch_size", type=int, required=True)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+@torch.inference_mode()
+def benchmark_forward(process_idx, args):
+    model = AutoDistributedModel.from_pretrained(
+        args.model,
+        initial_peers=args.initial_peers,
+        torch_dtype=DTYPE_MAP[args.torch_dtype],
+    )
+    logger.info(f"Created model: {process_idx=} {model.device=}")
+
+    torch.manual_seed(42)
+    for step in range(args.n_steps):
+        if step == args.warmup_steps:
+            start_time = perf_counter()
+
+        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))
+
+        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
+        h = model(input_ids)
+        # We don't use model.lm_head
+        logger.info(f"{process_idx=} Fwd end")
+
+        if step >= args.warmup_steps:
+            speed = step / (perf_counter() - start_time) * input_ids.numel()
+            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py
new file mode 100755
index 0000000..7b5f0e1
--- /dev/null
+++ b/benchmarks/benchmark_inference.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import torch
+from hivemind.utils.logging import get_logger
+from transformers import AutoTokenizer
+
+from petals import AutoDistributedModelForCausalLM
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=2048)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+@torch.inference_mode()
+def benchmark_inference(process_idx, args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
+    )
+    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
+
+    result = ""
+    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
+        for step in range(args.seq_len):
+            if step == args.warmup_steps:
+                start_time = perf_counter()
+
+            outputs = model.generate(max_new_tokens=1, session=sess)
+            result += tokenizer.decode(outputs[0])
+
+            if step >= args.warmup_steps:
+                speed = step / (perf_counter() - start_time)
+                logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py
new file mode 100755
index 0000000..46d0eb2
--- /dev/null
+++ b/benchmarks/benchmark_training.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import numpy as np
+import torch
+from hivemind.utils.logging import get_logger
+
+from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--task", type=str, default="cls")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--pre_seq_len", type=int, default=16)
+    parser.add_argument("--n_steps", type=int, default=10)
+    parser.add_argument("--batch_size", type=int, required=True)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    assert args.task in ["cls", "causal_lm"]
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+def benchmark_training(process_idx, args):
+    if args.task == "cls":
+        model = AutoDistributedModelForSequenceClassification.from_pretrained(
+            args.model,
+            initial_peers=args.initial_peers,
+            torch_dtype=DTYPE_MAP[args.torch_dtype],
+            tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len,
+            num_labels=2,
+        )
+    elif args.task == "causal_lm":
+        model = AutoDistributedModelForCausalLM.from_pretrained(
+            args.model,
+            initial_peers=args.initial_peers,
+            torch_dtype=DTYPE_MAP[args.torch_dtype],
+            tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len,
+        )
+    model = model.to(args.device)
+    opt = torch.optim.Adam(model.parameters())
+    logger.info(f"Created model: {process_idx=} {model.device=}")
+
+    torch.manual_seed(42)
+    fwd_times = []
+    bwd_times = []
+    for step in range(args.n_steps):
+        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
+        if args.task == "cls":
+            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
+        else:
+            labels = input_ids
+
+        logger.info(f"{process_idx=} {step=} Forward")
+        start_time = perf_counter()
+        outputs = model(input_ids, labels=labels)
+        fwd_times.append(perf_counter() - start_time)
+
+        logger.info(f"{process_idx=} {step=} Backward")
+        start_time = perf_counter()
+        outputs.loss.backward()
+        bwd_times.append(perf_counter() - start_time)
+
+        logger.info(f"{process_idx=} {step=} Optimizer step")
+        opt.step()
+        opt.zero_grad()
+
+        if step >= args.warmup_steps:
+            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
+            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
+            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
+
+    logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 4c6f0e5..83e35e5 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.server.server import DTYPE_MAP, Server
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
diff --git a/src/petals/constants.py b/src/petals/constants.py
index da047f1..b04ad03 100644
--- a/src/petals/constants.py
+++ b/src/petals/constants.py
@@ -1,3 +1,5 @@
+import torch
+
 PUBLIC_INITIAL_PEERS = [
     "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
     "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
@@ -7,3 +9,5 @@ PUBLIC_INITIAL_PEERS = [
 
 # The reachability API is currently used only when connecting to the public swarm
 REACHABILITY_API_URL = "http://health.petals.ml"
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py
index e4961d3..7644148 100644
--- a/src/petals/models/bloom/model.py
+++ b/src/petals/models/bloom/model.py
@@ -128,7 +128,7 @@ class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSeq
         self.num_labels = config.num_labels
 
         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index aab8a9e..62b9959 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -19,6 +19,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from transformers import PretrainedConfig
 from transformers.utils import get_file_from_repo
 
+from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
@@ -170,6 +171,3 @@ def _load_state_dict_from_file(
         except Exception as e:
             logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 75a999e..39c432c 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -16,13 +16,13 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
-from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.server.from_pretrained import load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
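Note (not part of the diff): the sketch below shows how the --torch_dtype flag in the new benchmark scripts resolves to a torch dtype through DTYPE_MAP, which this change moves from petals.server.from_pretrained into petals.constants. The mapping values are taken from the diff above; the comment about "auto" being resolved later (e.g., by resolve_block_dtype) is an assumption, not something the diff states.

# Illustrative sketch only: resolving a --torch_dtype string via the relocated DTYPE_MAP.
import torch

from petals.constants import DTYPE_MAP

torch_dtype = DTYPE_MAP["bfloat16"]        # -> torch.bfloat16, the benchmarks' default
assert DTYPE_MAP["float16"] is torch.float16
assert DTYPE_MAP["auto"] == "auto"         # "auto" stays a string; assumed to be resolved by later loading code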