@@ -4,6 +4,7 @@ import argparse
 import multiprocessing as mp
 from time import perf_counter
 
+import numpy as np
 import torch
 from hivemind.utils.logging import get_logger
 from transformers import AutoTokenizer
@@ -38,26 +39,29 @@ def main():
 
 @torch.inference_mode()
 def benchmark_inference(process_idx, args):
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
+    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
+
     model = AutoDistributedModelForCausalLM.from_pretrained(
         args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
-    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
+    logger.info(f"Created model: {process_idx=} {model.device=}")
 
     result = ""
+    step_times = []
     with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
         for step in range(args.seq_len):
-            if step == args.warmup_steps:
-                start_time = perf_counter()
+            start_time = perf_counter()
 
             outputs = model.generate(max_new_tokens=1, session=sess)
             result += tokenizer.decode(outputs[0])
 
             if step >= args.warmup_steps:
-                speed = step / (perf_counter() - start_time)
-                logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+                step_times.append(perf_counter() - start_time)
+                speed = 1 / np.mean(step_times)
+                logger.info(f"{process_idx=} {step=} {speed=:.2f}")
 
-    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+    logger.info(f"Final result: {process_idx=} {speed=:.2f}")
 
 
 if __name__ == "__main__":
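
Note on the timing change: the old code started one timer when warmup ended and reported cumulative speed as step / elapsed, which miscounts the number of timed steps and lets one slow early step bias every later reading. The new code times each generation step on its own and reports the reciprocal of the mean step time, so the in-loop readings and the final result are computed the same way. Below is a minimal, self-contained sketch of the new measurement scheme; run_step() and the constants are hypothetical stand-ins for model.generate(max_new_tokens=1, session=sess) and the script's CLI arguments, not part of the PR.

from time import perf_counter, sleep

import numpy as np

def run_step():
    # Hypothetical stand-in for one single-token generation step
    sleep(0.01)

warmup_steps, seq_len = 2, 10  # stand-ins for args.warmup_steps, args.seq_len
step_times = []
for step in range(seq_len):
    start_time = perf_counter()
    run_step()
    if step >= warmup_steps:
        # Record this step's latency, then average over all post-warmup
        # steps and invert to get steps (tokens) per second
        step_times.append(perf_counter() - start_time)
        speed = 1 / np.mean(step_times)
        print(f"{step=} {speed=:.2f}")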