Compare commits
128 Commits
Author | SHA1 | Date |
---|---|---|
Priyanshupareek | e268c99a6b | 2 months ago |
Artem Chumachenko | 30f522d1a0 | 2 months ago |
Artem Chumachenko | d6f4f80f3f | 2 months ago |
Artem Chumachenko | d2fcbbc72e | 3 months ago |
justheuristic | 2ad0b2b936 | 3 months ago |
justheuristic | efee5d1fa8 | 3 months ago |
Denis Mazur | 0d91bbdac3 | 4 months ago |
justheuristic | d59c15c578 | 7 months ago |
Max Ryabinin | 03cbe90234 | 7 months ago |
justheuristic | 25a0796b39 | 7 months ago |
justheuristic | dcce43670f | 7 months ago |
Alexander Borzunov | 82a97d6e9e | 8 months ago |
Alexander Borzunov | 47d50e1e29 | 8 months ago |
Max Ryabinin | ae19b65095 | 8 months ago |
Alexander Borzunov | 1d9401ddce | 9 months ago |
FYY | a2484b3053 | 9 months ago |
Alexander Borzunov | 5ce4f1a159 | 9 months ago |
Alexander Borzunov | 158621677b | 9 months ago |
Max Ryabinin | 1ebd88ae7b | 9 months ago |
Alexander Borzunov | d40eb6c701 | 9 months ago |
Alexander Borzunov | dd4a3230bc | 9 months ago |
Alexander Borzunov | b4d822afb2 | 9 months ago |
Alexander Borzunov | abd547735f | 9 months ago |
Alexander Borzunov | 6ef6bf5fa2 | 10 months ago |
Alexander Borzunov | 6bb3f54e39 | 10 months ago |
Alexander Borzunov | 02fc71eb25 | 10 months ago |
Alexander Borzunov | dc0072fde1 | 10 months ago |
Alexander Borzunov | a26559ff65 | 10 months ago |
Alexander Borzunov | 459933f846 | 10 months ago |
Alexander Borzunov | 26ebbfe8f0 | 10 months ago |
Alexander Borzunov | 75e516a8c1 | 10 months ago |
justheuristic | c08d09c4d3 | 10 months ago |
Alexander Borzunov | 90840dfea2 | 10 months ago |
Alexander Borzunov | 915b357740 | 10 months ago |
Alexander Borzunov | 18e93afc73 | 10 months ago |
Alexander Borzunov | 6967904590 | 10 months ago |
Alexander Borzunov | df8ab09ca2 | 10 months ago |
Artem Chumachenko | a14ae7334d | 10 months ago |
Alexander Borzunov | a9b0e9ff1a | 10 months ago |
justheuristic | 4f850996bb | 10 months ago |
justheuristic | 9250025140 | 10 months ago |
justheuristic | adda5f8c20 | 10 months ago |
Alexander Borzunov | de2475f31c | 10 months ago |
Alexander Borzunov | 063e94b4c8 | 10 months ago |
Artem Chumachenko | 568f21dc3b | 10 months ago |
Alexander Borzunov | 329f7d31e8 | 10 months ago |
Alexander Borzunov | 722c4dc496 | 10 months ago |
Alexander Borzunov | 056f22515a | 10 months ago |
justheuristic | 55eb36ef48 | 10 months ago |
Alexander Borzunov | 0e7189b3ed | 10 months ago |
Alexander Borzunov | 8c546d988a | 10 months ago |
Alexander Borzunov | df6fdd2d0b | 10 months ago |
Alexander Borzunov | 2a150770a4 | 10 months ago |
Alexander Borzunov | 00d48dcbe1 | 10 months ago |
justheuristic | ac9b546706 | 10 months ago |
Alexander Borzunov | 593d980ad8 | 10 months ago |
Alexander Borzunov | 32fbab5192 | 10 months ago |
Alexander Borzunov | b58141ef66 | 10 months ago |
Alexander Borzunov | 679397df0c | 10 months ago |
Vadim Peretokin | d0b5af34cd | 10 months ago |
Alexander Borzunov | a1f7791d5e | 11 months ago |
Alexander Borzunov | 351e96bc46 | 11 months ago |
Alexander Borzunov | 6a1b8a6a90 | 11 months ago |
Alexander Borzunov | 44fefa5e54 | 11 months ago |
Alexander Borzunov | cdc0f70653 | 11 months ago |
Guocheng | 8072cd9d1b | 11 months ago |
Alexander Borzunov | f3fafd14a4 | 11 months ago |
Alexander Borzunov | fd19c21859 | 11 months ago |
Alexander Borzunov | ffb20b585c | 11 months ago |
Alexander Borzunov | 48c6b6d963 | 11 months ago |
Alexander Borzunov | c153cba1fa | 11 months ago |
justheuristic | 5af04524dd | 11 months ago |
Alexander Borzunov | 30b94ef18b | 11 months ago |
Alexander Borzunov | 8666653cf5 | 11 months ago |
Alexander Borzunov | eb0664b993 | 11 months ago |
Alexander Borzunov | 6e4ebb94d2 | 11 months ago |
Alexander Borzunov | b6b3ae964f | 11 months ago |
Alexander Borzunov | d49d9ad0cf | 11 months ago |
justheuristic | e51e84631d | 11 months ago |
Aleksandr Borzunov | ddcda02b06 | 11 months ago |
Alexander Borzunov | b1ff8bdd6c | 11 months ago |
Alexander Borzunov | e9a20e7e53 | 11 months ago |
Alexander Borzunov | 057a2fb5de | 11 months ago |
Alexander Borzunov | 3218534745 | 11 months ago |
justheuristic | 398a384075 | 11 months ago |
justheuristic | 5a8de2f1f8 | 11 months ago |
Alexander Borzunov | 895327a0ae | 11 months ago |
Alexander Borzunov | c735dd7ba3 | 11 months ago |
justheuristic | 1ab35c2826 | 11 months ago |
Alexander Borzunov | a6fdfc0556 | 11 months ago |
Alexander Borzunov | f97582fb5f | 11 months ago |
Alexander Borzunov | 3b300c32e4 | 11 months ago |
Alexander Borzunov | 62d9ed5ce7 | 11 months ago |
Ikko Eltociear Ashimine | fd30f7ce10 | 11 months ago |
Alexander Borzunov | 11f0d992d7 | 11 months ago |
Alexander Borzunov | 9517dd1e3d | 11 months ago |
Alexander Borzunov | 3f733a96e3 | 11 months ago |
Alexander Borzunov | 81c4a45ca2 | 11 months ago |
Alexander Borzunov | 2c8959e713 | 11 months ago |
justheuristic | 37fdcb3fe0 | 11 months ago |
Alexander Borzunov | 9703358df0 | 11 months ago |
Alexander Borzunov | 1a78638c02 | 11 months ago |
justheuristic | c511990236 | 11 months ago |
Alexander Borzunov | e12d4c666b | 11 months ago |
justheuristic | 010857a834 | 11 months ago |
Alexander Borzunov | f605f093f7 | 11 months ago |
Alexander Borzunov | 90fbaab61e | 11 months ago |
Alexander Borzunov | 43acfe52a7 | 11 months ago |
Alexander Borzunov | 294970fe18 | 11 months ago |
Alexander Borzunov | 515a5120cb | 11 months ago |
Max Ryabinin | 13f4e3a88a | 11 months ago |
Artem Chumachenko | b9f0a5467f | 11 months ago |
Alexander Borzunov | dfc6578c8e | 11 months ago |
Alexander Borzunov | b28f5016ea | 11 months ago |
Alexander Borzunov | fa095f6461 | 11 months ago |
Alexander Borzunov | 158013a671 | 11 months ago |
Alexander Borzunov | 4d9c26fe5c | 11 months ago |
Alexander Borzunov | de930918a0 | 12 months ago |
Alexander Borzunov | 66a47c763e | 12 months ago |
Alexander Borzunov | 10c72acdf4 | 12 months ago |
Alexander Borzunov | d126ee3053 | 12 months ago |
Alexander Borzunov | fecee8c4dc | 12 months ago |
Alexander Borzunov | 47a2b1ee65 | 12 months ago |
Alexander Borzunov | 7a37513f77 | 12 months ago |
Alexander Borzunov | cb3f018f9f | 12 months ago |
Max Ryabinin | 5c0733711a | 1 year ago |
Max Ryabinin | c839173e57 | 1 year ago |
Max Ryabinin | 3e7ae5116d | 1 year ago |
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModel
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    step_times = []
    for step in range(args.warmup_steps + args.n_steps):
        start_time = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        h = model(input_ids)
        # We don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            step_times.append(perf_counter() - start_time)
            speed = input_ids.numel() / np.mean(step_times)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
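For reference, one way to exercise the forward-pass benchmark above is a plain command-line invocation. This is only a sketch: the script file name and the model repository are assumptions for illustration and do not appear in this diff.

```python
import subprocess

# Hypothetical invocation of the forward benchmark above; the script path and
# the model repo are placeholders, not part of this diff.
subprocess.run(
    [
        "python", "benchmark_forward.py",       # assumed file name
        "--model", "bigscience/bloom-petals",   # assumed model repo
        "--batch_size", "16",
        "--seq_len", "128",
        "--n_steps", "10",
    ],
    check=True,
)
```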
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    step_times = []
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                step_times.append(perf_counter() - start_time)
                speed = 1 / np.mean(step_times)
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
@@ -0,0 +1,107 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
    parser.add_argument("--task", type=str, default="cls", help="Training task type")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    assert args.task in ["cls", "causal_lm"]

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
    logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")


def benchmark_training(process_idx, args, result_pipe):
    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
            num_labels=2,
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
        )
    model = model.to(args.device)
    opt = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        start_time = perf_counter()
        outputs = model(input_ids, labels=labels)
        if step >= args.warmup_steps:
            fwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Backward")
        start_time = perf_counter()
        outputs.loss.backward()
        if step >= args.warmup_steps:
            bwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        opt.step()
        opt.zero_grad()

        if step >= args.warmup_steps:
            fwd_speed = input_ids.numel() / np.mean(fwd_times)
            bwd_speed = input_ids.numel() / np.mean(bwd_times)
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    result_pipe.send((fwd_speed, bwd_speed))


if __name__ == "__main__":
    main()
@@ -1,62 +0,0 @@
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import os
from typing import Optional, Tuple

import torch.nn.quantized.dynamic.modules.linear
import transformers
from packaging import version
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"


class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask
@@ -1,131 +0,0 @@
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
from __future__ import annotations

import itertools
import time
from typing import Optional, OrderedDict, Union

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import get_file_from_repo

from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import get_block_size
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

logger = get_logger(__name__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"


def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    with init_empty_weights():
        block = WrappedBloomBlock(config)

    state_dict = _load_state_dict(
        converted_model_name_or_path,
        block_index,
        config,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block


def _load_state_dict(
    pretrained_model_name_or_path: str,
    block_index: int,
    config: BloomConfig,
    *,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    min_backoff: float = 5,
) -> OrderedDict[str, torch.Tensor]:
    revision = BLOCK_BRANCH_PREFIX + str(block_index)

    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            archive_file = get_file_from_repo(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if archive_file is not None:
                return torch.load(archive_file, map_location="cpu")
    except Exception:
        logger.debug(
            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
        )

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    for attempt_no in itertools.count():
        try:
            with allow_cache_writes(cache_dir):
                block_size = get_block_size(config, "disk")
                free_disk_space_for(
                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
                )

                archive_file = get_file_from_repo(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                return torch.load(archive_file, map_location="cpu")
        except Exception as e:
            delay = min_backoff * (2**attempt_no)
            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)


DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
@@ -1,99 +0,0 @@
"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""

import platform

import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import BloomConfig

logger = get_logger(__name__)


class LMHead(nn.Module):
    """
    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
    """

    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings

        self.use_chunked_forward = config.use_chunked_forward
        if self.use_chunked_forward == "auto":
            if platform.machine() == "x86_64":
                # Import of cpufeature may crash on non-x86_64 machines
                from cpufeature import CPUFeature

                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
                # Otherwise, it's ~8x slower.
                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
            else:
                self.use_chunked_forward = True
        self.chunked_forward_step = config.chunked_forward_step
        self._bf16_warning_shown = False

    @property
    def in_features(self) -> int:
        return self.word_embeddings.num_embeddings

    @property
    def out_features(self) -> int:
        return self.word_embeddings.embedding_dim

    @property
    def weight(self):
        return self.word_embeddings.weight

    @property
    def bias(self):
        return None

    def forward(self, hidden_states):
        word_embeddings = self.word_embeddings.weight

        if (
            word_embeddings.dtype in [torch.float16, torch.bfloat16]
            and word_embeddings.device.type == "cpu"
            and self.use_chunked_forward
        ):
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                logger.warning(
                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                    "Consider loading the model with torch_dtype='float32'"
                )
            self._bf16_warning_shown = True

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)

        for i in range(0, num_embeddings, self.chunked_forward_step):
            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
        return output
@@ -1,20 +0,0 @@
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
@@ -1,98 +0,0 @@
import argparse
import os

import psutil
import torch.backends.quantized
import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger
from huggingface_hub import HfApi, Repository
from tqdm.auto import tqdm
from transformers.models.bloom.modeling_bloom import BloomModel

from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig

logger = get_logger(__name__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")


def main():
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    api = HfApi(token=args.use_auth_token)
    api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")


if __name__ == "__main__":
    main()
@@ -1,79 +0,0 @@
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
  echo " -m: model name"
  echo " -i: initial peer"
  echo " -d: device" >&2
  echo " -p: server identity path" >&2
  echo " -b: block_ids" >&2
  echo " -a: host maddrs" >&2
  echo " -t: whether to run local tests" >&2
  exit 1
}

if [ ! $# -ge 8 ]; then
    instructions
fi

while getopts ":m:i:d:p:b:a:t:" option; do
    case $option in
        m) MODEL_NAME=${OPTARG}
            ;;
        i) INITIAL_PEER=${OPTARG}
            ;;
        d) DEVICE=${OPTARG}
            ;;
        p) SERVER_ID_PATH=${OPTARG}
            ;;
        b) BLOCK_IDS=${OPTARG}
            ;;
        a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
            ;;
        t) RUN_LOCAL_TESTS=true
            ;;
        \?) instructions
            ;;
    esac
done


echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"


###########################
# Install or activate env #
###########################

# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple -r .
    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi

##############
# Run server #
##############

python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
@@ -1,51 +0,0 @@
import argparse

import torch
from hivemind.utils.logging import get_logger
from tqdm.auto import trange
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

from petals.bloom.block import BloomBlock

logger = get_logger(__name__)

logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional Info when using cuda
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info(f"Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config).to(args.device)

    cache = None

    for i in trange(args.num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
        with torch.no_grad():
            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)
@@ -1,5 +0,0 @@
device=cpu
block_ids=2:3
id_path=./server.id
maddr=/ip4/127.0.0.1/tcp/30000
#
@@ -1,6 +0,0 @@
name=bloom-peer-0.bloom.net
device=cpu
block_ids=1:3
id_path=./server.id
maddr=/ip4/0.0.0.0/tcp/30000
#
@@ -1,109 +0,0 @@
# !/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-n] [-c]" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 4 ]; then
    instructions
fi

while getopts ":n:c:t:" option; do
    case $option in
        n) NUM_SERVERS=${OPTARG}
            ;;
        c) CONFIG_PATH=${OPTARG}
            ;;
        \?) instructions
            ;;
    esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple -r .
    pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
    [device]="cpu"
    [block_ids]="1:2"
    [id_path]="server.id"
    [maddr]="/ip4/127.0.0.1/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
    ###############
    # Read config #
    ###############

    while read line
    do
        if echo $line | grep -F = &>/dev/null
        then
            varname=$(echo "$line" | cut -d '=' -f 1)
            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
        fi
    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

    echo "=== Server #${SERVER_ID} ==="
    echo "Server ID: ${cfg[id_path]}"
    echo "Device: ${cfg[device]}"
    echo "Bloom block ids: ${cfg[block_ids]}"
    echo "Host maddr: ${cfg[maddr]}"
    echo ""

    ##############
    # Run server #
    ##############

    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done

#####################
# Kill initial peer #
#####################

sleep 10
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
rm tmp.out
@@ -1,110 +0,0 @@
# !/usr/bin/env bash

SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-u] [-n] [-c]" >&2
  echo " -u: username" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 6 ]; then
    instructions
fi

while getopts ":u:n:c:" option; do
    case $option in
        u) USERNAME=${OPTARG}
            ;;
        n) NUM_SERVERS=${OPTARG}
            ;;
        c) CONFIG_PATH=${OPTARG}
            ;;
        \?) instructions
            ;;
    esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple -r .
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &

sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
    [name]=""
    [device]="cpu"
    [block_ids]="1:2"
    [id_path]="server.id"
    [maddr]="/ip4/0.0.0.0/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
    ###############
    # Read config #
    ###############

    while read line
    do
        if echo $line | grep -F = &>/dev/null
        then
            varname=$(echo "$line" | cut -d '=' -f 1)
            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
        fi
    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

    SERVER_NAME="${USERNAME}@${cfg[name]}"
    echo "=== Server #${SERVER_ID} ==="
    echo "Server name ${SERVER_NAME}"
    echo "Server ID: ${cfg[id_path]}"
    echo "Device: ${cfg[device]}"
    echo "Bloom block ids: ${cfg[block_ids]}"
    echo "Host maddr: ${cfg[maddr]}"
    echo "================="

    ##############
    # Run server #
    ##############

    ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
@@ -1,10 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import (
    DistributedBloomConfig,
    DistributedBloomForCausalLM,
    DistributedBloomForSequenceClassification,
    DistributedBloomModel,
)
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
@@ -0,0 +1,35 @@
import dataclasses
import os
from typing import Optional, Sequence, Union

from hivemind import PeerID

from petals.constants import PUBLIC_INITIAL_PEERS

_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None


@dataclasses.dataclass
class ClientConfig:
    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers

    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
    use_server_to_server: bool = True  # Use direct server-to-server communication

    connect_timeout: float = 5  # timeout for opening a connection
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds

    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)

    max_pinged: int = 3  # max servers to ping from each sequence side, per update
    ping_timeout: float = 2  # max time to wait for pings, per update
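A minimal sketch of how settings like these are typically consumed on the client side. The model repository and the exact kwargs-forwarding behavior are assumptions for illustration and are not shown in this diff.

```python
from petals import AutoDistributedModelForCausalLM

# Sketch: client-side options (timeouts, server allow/block lists, etc.) are usually
# passed as keyword arguments when creating a distributed model; the model repo is a placeholder.
model = AutoDistributedModelForCausalLM.from_pretrained(
    "bigscience/bloom-petals",   # assumed model repo
    request_timeout=5 * 60,      # override the 3-minute default above
    allowed_servers=None,        # or a list of PeerID / str values to restrict routing
)
```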
@@ -0,0 +1,84 @@
import contextlib
import json
import os
import re
import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union

from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils

from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)


class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )


_ignored_keys = ContextVar("ignored_keys", default=None)


@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    token = _ignored_keys.set(patterns)
    try:
        yield
    finally:
        _ignored_keys.reset(token)


def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""

    should_ignore_keys = _ignored_keys.get() is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)


original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
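The ContextVar-based `ignore_keys()` pattern above can be shown in isolation. This is a self-contained sketch of the same idea; the names below are illustrative and not part of the Petals API.

```python
import contextlib
import re
from contextvars import ContextVar

_skipped_patterns = ContextVar("skipped_patterns", default=None)  # illustrative name

@contextlib.contextmanager
def skip_params(patterns):
    """Temporarily record regex patterns for parameters that loading code may ignore."""
    token = _skipped_patterns.set(patterns)
    try:
        yield
    finally:
        _skipped_patterns.reset(token)

def filter_weight_map(weight_map: dict) -> dict:
    """Drop entries whose names match any active pattern (mirrors the shard-index filtering above)."""
    patterns = _skipped_patterns.get() or []
    return {
        name: filename
        for name, filename in weight_map.items()
        if all(re.search(p, name) is None for p in patterns)
    }

with skip_params([r"lm_head\."]):
    print(filter_weight_map({"lm_head.weight": "shard1.bin", "h.0.mlp.weight": "shard2.bin"}))
    # -> {'h.0.mlp.weight': 'shard2.bin'}
```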
@@ -0,0 +1,82 @@
import dataclasses
import platform
from typing import Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import PretrainedConfig

logger = get_logger(__name__)


@dataclasses.dataclass
class LMHeadConfig:
    # This settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    use_chunked_forward: Union[str, bool] = "auto"
    chunked_forward_step: int = 16384


class LMHead(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        if not config.tie_word_embeddings:
            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
            self.weight.requires_grad = False
        else:
            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
        self.bias = None
        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
        self.out_features = config.vocab_size

        self.use_chunked_forward = config.use_chunked_forward
        if self.use_chunked_forward == "auto":
            if platform.machine() == "x86_64":
                # Import of cpufeature may crash on non-x86_64 machines
                from cpufeature import CPUFeature

                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
                # Otherwise, it's ~8x slower.
                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
            else:
                self.use_chunked_forward = True
        self.chunked_forward_step = config.chunked_forward_step
        self._bf16_warning_shown = False

    def forward(self, hidden_states):
        if (
            self.weight.dtype in [torch.float16, torch.bfloat16]
            and self.weight.device.type == "cpu"
            and self.use_chunked_forward
        ):
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(self.weight.dtype)
            lm_logits = F.linear(hidden_states, self.weight)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
            logger.warning(
                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
            )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], self.out_features)

        for i in range(0, self.out_features, self.chunked_forward_step):
            chunk = self.weight[i : i + self.chunked_forward_step].float()
            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
        return output
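The trade-off behind `chunked_forward()` can be demonstrated standalone: casting the bf16 weight matrix to fp32 one slice at a time bounds the extra memory to a single chunk while keeping the matmul itself in fp32. A minimal sketch with arbitrary shapes, independent of the Petals classes above:

```python
import torch
import torch.nn.functional as F

def chunked_logits(hidden: torch.Tensor, weight_bf16: torch.Tensor, chunk: int = 4096) -> torch.Tensor:
    """Compute hidden @ weight.T in fp32, casting the bf16 weight chunk by chunk (sketch of the idea above)."""
    hidden = hidden.float()
    out = torch.empty(*hidden.shape[:-1], weight_bf16.shape[0])
    for i in range(0, weight_bf16.shape[0], chunk):
        # Only one chunk of the weight ever exists in fp32 at a time.
        out[..., i : i + chunk] = F.linear(hidden, weight_bf16[i : i + chunk].float())
    return out

hidden = torch.randn(2, 8, 64)
weight = torch.randn(16384, 64, dtype=torch.bfloat16)
assert chunked_logits(hidden, weight).shape == (2, 8, 16384)
```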
@ -0,0 +1,84 @@
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind import get_logger
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from petals.utils.misc import DUMMY
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PTuneConfig:
|
||||
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
||||
tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
|
||||
|
||||
|
||||
class PTuneMixin:
|
||||
_keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]
|
||||
|
||||
def init_prompts(self, config: PretrainedConfig) -> None:
|
||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
||||
self.pre_seq_len = config.pre_seq_len
|
||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
||||
|
||||
with force_non_empty_weights():
|
||||
# Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
|
||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
||||
if config.tuning_mode == "deep_ptune":
|
||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
||||
self.pre_seq_len,
|
||||
config.num_hidden_layers * config.hidden_size,
|
||||
# ^-- TODO: should be num_hidden_layers - 1
|
||||
dtype=torch.float32,
|
||||
)
|
||||
elif config.tuning_mode:
|
||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
||||
|
||||
def get_prompt(self, batch_size):
|
||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
||||
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
||||
prompts = self.prompt_embeddings(prefix_tokens)
|
||||
|
||||
if self.config.tuning_mode == "deep_ptune":
|
||||
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
||||
intermediate_prompts = intermediate_prompts.view(
|
||||
batch_size,
|
||||
self.pre_seq_len,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.hidden_size
|
||||
# TODO: should be num_hidden_layers - 1
|
||||
)
|
||||
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
||||
else:
|
||||
intermediate_prompts = DUMMY
|
||||
|
||||
dtype = self.word_embeddings.weight.dtype
|
||||
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
||||
|
||||
|
||||
_original_register_parameter = nn.Module.register_parameter
|
||||
|
||||
|
||||
@contextmanager
|
||||
def force_non_empty_weights():
|
||||
"""
|
||||
This context manager allows bypassing the accelerate.init_empty_weights() context manager
(which forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
The transformers library should replace all meta tensors with empty tensors by itself,
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
||||
|
||||
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
||||
"""
|
||||
|
||||
possibly_patched_register_parameter = nn.Module.register_parameter
|
||||
nn.Module.register_parameter = _original_register_parameter
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
nn.Module.register_parameter = possibly_patched_register_parameter
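
# Illustrative usage sketch (not part of the diff): materialize a real parameter even while
# accelerate.init_empty_weights() has patched nn.Module.register_parameter (low_cpu_mem_usage=True),
# mirroring how init_prompts() above creates its prompt embeddings.
#
#   with force_non_empty_weights():
#       prompt_embeddings = nn.Embedding(16, 1024, dtype=torch.float32)  # real tensor, not a meta tensor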
|
@ -1,349 +1,164 @@
|
||||
import contextlib
|
||||
from typing import List, Optional
|
||||
import dataclasses
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ContextManager, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from hivemind.utils.logging import get_logger
|
||||
from torch import Tensor
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.generation.utils import ModelOutput
|
||||
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.utils.generation_algorithms import (
|
||||
BeamSearchAlgorithm,
|
||||
DecodingAlgorithm,
|
||||
GreedyAlgorithm,
|
||||
NucleusAlgorithm,
|
||||
SamplingAlgorithm,
|
||||
TopKAlgorithm,
|
||||
)
|
||||
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.utils.misc import DUMMY, docstring_from
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RemoteGenerationMixin:
|
||||
class RemotePastKeyValues(Cache):
|
||||
"""only keeps the number of seen tokens. pretends to be a legit cache"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.seen_tokens = 0
|
||||
self.hypo_ids: Optional[torch.LongTensor] = None
|
||||
|
||||
def __getitem__(self, _index: int) -> List[torch.Tensor]:
|
||||
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
return self.seen_tokens
|
||||
|
||||
def get_max_length(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def update_seen(self, new_seen: int) -> None:
|
||||
self.seen_tokens += new_seen
|
||||
|
||||
def reorder_cache(self, beam_idx):
|
||||
raise NotImplementedError("Beam search reordering is not implemented yet")
|
||||
|
||||
|
||||
_skipped_tokens = ContextVar("skipped_tokens", default=0)
|
||||
|
||||
|
||||
class _SkipTokensMixin:
|
||||
# This override is used in RemoteGenerationMixin but has to be defined in a class not named "GenerationMixin"
# due to how transformers.PreTrainedModel.can_generate() works
|
||||
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
|
||||
input_ids = input_ids[:, _skipped_tokens.get() :]
|
||||
_skipped_tokens.set(0)
|
||||
return super().prepare_inputs_for_generation(input_ids, **kwargs)
|
||||
|
||||
|
||||
class RemoteGenerationMixin(_SkipTokensMixin):
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
|
||||
The class can be used for:
|
||||
- *greedy decoding*.
|
||||
- *multinomial, top-k and top-p sampling*.
|
||||
- *beam-search decoding*
|
||||
|
||||
This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
However, it has some differences for remote usage.
|
||||
This class is an upgrade to `transformers.GenerationMixin` that:
|
||||
|
||||
- Designed to be compatible with most `transformers.GenerationMixin` strategies and options
|
||||
- Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
|
||||
you don't have to rerun the prefix through all the servers to generate each new token
|
||||
- Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
|
||||
by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
|
||||
accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
|
||||
- If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
|
||||
Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
|
||||
"""
|
||||
|
||||
def inference_session(self, **kwargs) -> InferenceSession:
|
||||
"""
|
||||
Returns an inference session for the model's RemoteSequential module.
|
||||
@docstring_from(RemoteSequential.active_session)
|
||||
@property
|
||||
def active_session(self) -> Optional[InferenceSession]:
|
||||
return self.transformer.h.active_session
|
||||
|
||||
:param max_length: Maximal expected length of inference results. Servers use this parameter
|
||||
to calculate the size of attention caches allocated to this client.
|
||||
"""
|
||||
@docstring_from(RemoteSequential.use_session)
|
||||
def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
|
||||
return self.transformer.h.use_session(session)
|
||||
|
||||
@docstring_from(RemoteSequential.inference_session)
|
||||
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
|
||||
return self.transformer.h.inference_session(**kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@docstring_from(transformers.GenerationMixin.generate.__doc__)
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
do_sample: Optional[bool] = None,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
num_beams: Optional[int] = 1,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
decoding_algorithm: Optional[DecodingAlgorithm] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
num_return_sequences: Optional[int] = None,
|
||||
session: Optional[InferenceSession] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head.
|
||||
|
||||
:param inputs: The input tokens to the model.
|
||||
:param do_sample: Whether to sample from the model predictions or take the argmax.
|
||||
:param temperature: The temperature to use for sampling.
|
||||
:param top_k: The number of highest-probability tokens to keep for top-k sampling.
:param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
|
||||
:param num_beams: The number of beams to use for beam search.
|
||||
:param bos_token_id: The id of the beginning of sentence token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param max_length: The maximum number of tokens in the output (including input tokens).
|
||||
:param max_new_tokens: The maximum number of tokens to generate.
|
||||
:param decoding_algorithm: The decoding algorithm to use.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
:param num_return_sequences: How many hypotheses from the beam will be returned in the output.
|
||||
"""
|
||||
|
||||
prefix_length = 0 if inputs is None else inputs.size(1)
|
||||
prefix_length += self.config.pre_seq_len
|
||||
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
|
||||
if max_length is not None and max_new_tokens is None:
|
||||
max_new_tokens = max_length - prefix_length
|
||||
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
|
||||
elif max_length is None and max_new_tokens is not None:
|
||||
max_length = prefix_length + max_new_tokens
|
||||
|
||||
resuming_session = session is not None and session.last_token_id is not None
|
||||
if num_beams > 1 and resuming_session:
|
||||
raise NotImplementedError(
|
||||
"Resuming inference session in .generate() along with beam search is not supported yet"
|
||||
)
|
||||
|
||||
if inputs is not None:
|
||||
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
||||
if resuming_session:
|
||||
inputs = torch.cat([session.last_token_id, inputs], dim=1)
|
||||
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
|
||||
):
|
||||
self._fix_generate_kwargs(kwargs)
|
||||
if inputs is None:
|
||||
inputs = kwargs.pop("input_ids", None)
|
||||
|
||||
if session is not None:
|
||||
# If a session specified explicitly, use it
|
||||
context_manager = self.use_session(session)
|
||||
elif self.active_session is not None:
|
||||
# If there's an active session, don't do anything
|
||||
context_manager = contextlib.nullcontext(self.active_session)
|
||||
else:
|
||||
if resuming_session:
|
||||
inputs = session.last_token_id
|
||||
else:
|
||||
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
||||
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
|
||||
batch_size = inputs.size(0)
|
||||
|
||||
if decoding_algorithm is None:
|
||||
if do_sample:
|
||||
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
|
||||
elif num_beams is not None and num_beams > 1:
|
||||
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
|
||||
# If there's no active session, create a new one
|
||||
|
||||
max_length = kwargs.get("max_length")
|
||||
max_new_tokens = kwargs.get("max_new_tokens")
|
||||
assert (max_length is None) != (
|
||||
max_new_tokens is None
|
||||
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
|
||||
|
||||
session_max_length = self.transformer.config.pre_seq_len
|
||||
if max_length is not None:
|
||||
session_max_length += max_length
|
||||
else:
|
||||
if top_k is not None or top_p is not None:
|
||||
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
|
||||
decoding_algorithm = GreedyAlgorithm()
|
||||
|
||||
if num_beams > 1:
|
||||
inputs = torch.cat([inputs] * num_beams, dim=0)
|
||||
if batch_size > 1:
|
||||
# TODO: resolve padding problem
|
||||
logger.warning(
|
||||
f"You set batch_size {batch_size} within beam search generation. "
|
||||
f"Be careful, results on sequences with different length may be padded wrong way"
|
||||
)
|
||||
|
||||
if num_return_sequences is None:
|
||||
num_return_sequences = 1
|
||||
|
||||
assert num_return_sequences <= num_beams, (
|
||||
f"You want more sequences than the beam has."
|
||||
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
|
||||
)
|
||||
|
||||
constraints = self._get_constraints(
|
||||
inputs=inputs,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
|
||||
if session is None:
|
||||
context_manager = self.inference_session(max_length=max_length)
|
||||
else:
|
||||
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
|
||||
session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
|
||||
context_manager = self.inference_session(max_length=session_max_length)
|
||||
|
||||
with context_manager as session:
|
||||
outputs = []
|
||||
# Find samples with padded inputs.
# They will be changed until all of the samples have the right length.
|
||||
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
|
||||
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
|
||||
# Prepend the tokens from the previous .generate() call
|
||||
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
|
||||
if n_prev_tokens > 0:
|
||||
if kwargs.get("num_beams", 1) > 1:
|
||||
logger.warning(
|
||||
"Beam search will not work properly in the resumed petals.InferenceSession "
|
||||
"since intermediate beam entries are lost"
|
||||
)
|
||||
|
||||
if inputs is not None:
|
||||
inputs = torch.cat([session.output_ids, inputs], dim=1)
|
||||
else:
|
||||
inputs = session.output_ids
|
||||
|
||||
# Don't actually run all previous tokens through the transformer,
|
||||
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
|
||||
_skipped_tokens.set(max(0, n_prev_tokens - 1))
|
||||
|
||||
if self._supports_cache_class and "past_key_values" not in kwargs:
|
||||
past_key_values = RemotePastKeyValues()
|
||||
past_key_values.update_seen(session.position)
|
||||
kwargs["past_key_values"] = past_key_values
|
||||
|
||||
result = super().generate(inputs, *args, **kwargs)
|
||||
|
||||
sequences = result.sequences if isinstance(result, ModelOutput) else result
|
||||
# Save tokens from this .generate() call
|
||||
session.output_ids = sequences
|
||||
# Crop the last tokens from the previous call
|
||||
sequences = sequences[:, n_prev_tokens:].clone()
|
||||
if isinstance(result, ModelOutput):
|
||||
result.sequences = sequences
|
||||
else:
|
||||
outputs += [inputs]
|
||||
last_token_id = None
|
||||
seq_idx = outputs[0].size(1)
|
||||
hypo_ids = torch.arange(outputs[0].size(0))
|
||||
while True:
|
||||
hidden_state = self.transformer.word_embeddings(outputs[-1])
|
||||
intermediate_prompts = None
|
||||
if self.config.pre_seq_len > 0 and len(outputs) == 1:
|
||||
prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
|
||||
hidden_state = torch.cat([prompts, hidden_state], dim=1)
|
||||
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
|
||||
|
||||
hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
|
||||
|
||||
hidden_state = self.transformer.ln_f(hidden_state)
|
||||
lm_logits = self.lm_head(hidden_state)
|
||||
|
||||
for constraint in constraints:
|
||||
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
|
||||
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
|
||||
|
||||
# If some samples were padded, change only these samples
|
||||
if seq_idx < inputs.size(1):
|
||||
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
|
||||
last_token_id = (~pad_token_mask) * inputs[
|
||||
:, seq_idx : seq_idx + 1
|
||||
] + pad_token_mask * last_token_id
|
||||
|
||||
# TODO: refactor outputs
|
||||
if num_beams > 1:
|
||||
for i in range(len(outputs), 1, -1):
|
||||
outputs[i - 1] = outputs[i - 1][hypo_ids]
|
||||
|
||||
outputs.append(last_token_id)
|
||||
session.last_token_id = last_token_id
|
||||
seq_idx += 1
|
||||
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
|
||||
break
|
||||
|
||||
outputs = torch.cat(outputs, dim=-1)
|
||||
|
||||
if resuming_session:
|
||||
outputs = outputs[:, 1:]
|
||||
if num_beams > 1:
|
||||
pre_return_idx = [
|
||||
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
|
||||
]
|
||||
return_idx = torch.cat(pre_return_idx, dim=0)
|
||||
outputs = outputs[return_idx]
|
||||
|
||||
return outputs
|
||||
|
||||
def greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
|
||||
|
||||
:param input_ids: The input tokens to the model.
|
||||
:param max_length: The maximum length of the sequence to generate.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=GreedyAlgorithm(),
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
|
||||
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
|
||||
|
||||
:param input_ids: The input tokens to the model.
:param temperature: The temperature to use for sampling.
:param top_k: The number of highest-probability tokens to keep for top-k sampling.
:param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
num_beams: int = 1,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses beam search.
|
||||
|
||||
:param input_ids: The input tokens to the model.
|
||||
:param num_beams: The number of beams to use.
|
||||
:param max_length: The maximum length of the sequence to generate.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
decoding_algorithm = BeamSearchAlgorithm(
|
||||
num_beams=num_beams,
|
||||
batch_size=input_ids.size(0),
|
||||
)
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
num_beams=num_beams,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=decoding_algorithm,
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
|
||||
def beam_sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def group_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _choose_sample_algorithm(
|
||||
self,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> DecodingAlgorithm:
|
||||
if (top_k is not None) and (top_p is not None):
|
||||
raise ValueError("You have to provide only top_k or top_p for sampling")
|
||||
if top_k is not None:
|
||||
return TopKAlgorithm(top_k, temperature)
|
||||
elif top_p is not None:
|
||||
return NucleusAlgorithm(top_p, temperature)
|
||||
else:
|
||||
return SamplingAlgorithm(temperature)
|
||||
|
||||
def _get_constraints(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> List[ABCBloomConstraint]:
|
||||
constraints = []
|
||||
constraints.extend(provided_constraints)
|
||||
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
|
||||
return constraints
|
||||
result = sequences
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _fix_generate_kwargs(kwargs: dict):
|
||||
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
|
||||
if "max_length" in kwargs and kwargs["max_length"] is None:
|
||||
del kwargs["max_length"]
|
||||
|
||||
# Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
|
||||
do_sample = kwargs.get("do_sample")
|
||||
if isinstance(do_sample, int):
|
||||
kwargs["do_sample"] = bool(do_sample)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
|
||||
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
|
||||
|
@ -1,269 +0,0 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Collection, List, Optional, Union
|
||||
|
||||
import hivemind
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.bloom import (
|
||||
BloomConfig,
|
||||
BloomForCausalLM,
|
||||
BloomForSequenceClassification,
|
||||
BloomModel,
|
||||
BloomPreTrainedModel,
|
||||
)
|
||||
|
||||
from petals.bloom.modeling_utils import LMHead
|
||||
from petals.client.remote_generation import RemoteGenerationMixin
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||
from petals.utils.misc import DUMMY
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
|
||||
"""
|
||||
A bloom config that contains information about DHT peers.
|
||||
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
|
||||
"""
|
||||
|
||||
initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT
|
||||
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
||||
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
|
||||
|
||||
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
||||
tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
|
||||
|
||||
# These settings matter for running the client with dtype bfloat16 on CPU.
|
||||
# If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
|
||||
use_chunked_forward: Union[str, bool] = "auto"
|
||||
chunked_forward_step: int = 16384
|
||||
|
||||
|
||||
original_register_parameter = nn.Module.register_parameter
|
||||
|
||||
|
||||
@contextmanager
|
||||
def force_non_empty_weights():
|
||||
"""
|
||||
This context manager allows bypassing the accelerate.init_empty_weights() context manager
(which forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
The transformers library should replace all meta tensors with empty tensors by itself,
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
||||
|
||||
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
||||
"""
|
||||
|
||||
try:
|
||||
possibly_patched_register_parameter = nn.Module.register_parameter
|
||||
nn.Module.register_parameter = original_register_parameter
|
||||
yield
|
||||
finally:
|
||||
nn.Module.register_parameter = possibly_patched_register_parameter
|
||||
|
||||
|
||||
class _FromPretrainedDefaultsMixin:
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
*args,
|
||||
low_cpu_mem_usage: Optional[bool] = None,
|
||||
torch_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
if torch_dtype is None:
|
||||
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
|
||||
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
|
||||
torch_dtype = "auto"
|
||||
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
|
||||
|
||||
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
||||
"low_cpu_mem_usage(`bool`, *optional*)",
|
||||
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
|
||||
).replace(
|
||||
"torch_dtype (`str` or `torch.dtype`, *optional*)",
|
||||
'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
|
||||
)
|
||||
|
||||
|
||||
class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
|
||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
||||
r"^(intermediate_)?prompt_embeddings\.weight$",
|
||||
]
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
||||
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
|
||||
|
||||
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
||||
super().__init__(config)
|
||||
assert len(self.h) == 0
|
||||
config.n_layer = n_layer
|
||||
|
||||
self.h = RemoteSequential(config, dht=dht)
|
||||
|
||||
# Forbid accumulate grads for embeddings and layernorm
|
||||
self.set_requires_grad(False)
|
||||
|
||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
||||
self.pre_seq_len = config.pre_seq_len
|
||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
||||
|
||||
with force_non_empty_weights():
|
||||
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
||||
logger.info(
|
||||
"Prompt embeddings and their optimizer statistics will be kept in float32 "
|
||||
"to increase ptune quality"
|
||||
)
|
||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
||||
if config.tuning_mode == "deep_ptune":
|
||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
||||
self.pre_seq_len,
|
||||
config.num_hidden_layers * config.hidden_size,
|
||||
# ^-- TODO: should be num_hidden_layers - 1
|
||||
dtype=torch.float32,
|
||||
)
|
||||
elif config.tuning_mode:
|
||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
||||
|
||||
def set_requires_grad(self, value):
|
||||
for p in self.parameters():
|
||||
p.requires_grad = value
|
||||
|
||||
def get_prompt(self, batch_size):
|
||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
||||
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
||||
prompts = self.prompt_embeddings(prefix_tokens)
|
||||
|
||||
if self.config.tuning_mode == "deep_ptune":
|
||||
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
||||
intermediate_prompts = intermediate_prompts.view(
|
||||
batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
|
||||
)
|
||||
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
||||
else:
|
||||
intermediate_prompts = DUMMY
|
||||
|
||||
dtype = self.word_embeddings.weight.dtype
|
||||
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if not (v is None or v is False):
|
||||
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
||||
else:
|
||||
hidden_states = self.h(hidden_states)
|
||||
|
||||
# Remove prefix
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=None,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
|
||||
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = (
|
||||
BloomForCausalLM._keys_to_ignore_on_load_missing
|
||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
||||
)
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.lm_head = LMHead(config, self.transformer.word_embeddings)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.transformer.word_embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
return None
|
||||
return self.lm_head
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding):
|
||||
assert isinstance(new_embeddings, nn.Embedding)
|
||||
self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
|
||||
assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
|
||||
|
||||
def set_output_embeddings(self, new_lm_head: nn.Linear):
|
||||
with torch.no_grad():
|
||||
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
||||
self.lm_head.bias[...] = new_lm_head.bias
|
||||
|
||||
|
||||
class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
|
||||
_keys_to_ignore_on_load_missing = (
|
||||
BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
)
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
@ -1 +1,2 @@
|
||||
"""Client-side functions responsible for choosing the best server, """
|
||||
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
|
||||
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||
|
@ -1,9 +1,18 @@
|
||||
import torch
|
||||
|
||||
PUBLIC_INITIAL_PEERS = [
|
||||
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||
# IPv4 DNS addresses
|
||||
"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||
"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||
# IPv6 DNS addresses
|
||||
"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||
"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||
# Reserved IPs
|
||||
"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||
"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||
]
|
||||
|
||||
# The reachability API is currently used only when connecting to the public swarm
|
||||
REACHABILITY_API_URL = "http://health.petals.ml"
|
||||
REACHABILITY_API_URL = "https://health.petals.dev"
|
||||
|
||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
||||
|
@ -1,123 +1,9 @@
|
||||
"""
|
||||
Utilities for declaring and retrieving active model layers using a shared DHT.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
warnings.warn(
|
||||
"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from hivemind.dht import DHT, DHTNode, DHTValue
|
||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.p2p import PeerID
|
||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
||||
|
||||
import petals.client
|
||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def declare_active_modules(
|
||||
dht: DHT,
|
||||
uids: Sequence[ModuleUID],
|
||||
expiration_time: DHTExpiration,
|
||||
state: ServerState,
|
||||
throughput: float,
|
||||
wait: bool = True,
|
||||
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
|
||||
"""
|
||||
Declare that your node serves the specified modules; update timestamps if declared previously
|
||||
|
||||
:param uids: a list of module ids to declare
|
||||
:param wait: if True, waits for the declaration to finish; otherwise, runs in the background
|
||||
:param throughput: specify your performance in terms of compute throughput
|
||||
:param expiration_time: declared modules will be visible for this many seconds
|
||||
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
|
||||
"""
|
||||
if isinstance(uids, str):
|
||||
uids = [uids]
|
||||
if not isinstance(uids, list):
|
||||
uids = list(uids)
|
||||
for uid in uids:
|
||||
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
|
||||
return dht.run_coroutine(
|
||||
partial(
|
||||
_declare_active_modules,
|
||||
uids=uids,
|
||||
expiration_time=expiration_time,
|
||||
state=state,
|
||||
throughput=throughput,
|
||||
),
|
||||
return_future=not wait,
|
||||
)
|
||||
|
||||
|
||||
async def _declare_active_modules(
|
||||
dht: DHT,
|
||||
node: DHTNode,
|
||||
uids: List[ModuleUID],
|
||||
expiration_time: DHTExpiration,
|
||||
state: ServerState,
|
||||
throughput: float,
|
||||
) -> Dict[ModuleUID, bool]:
|
||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||
return await node.store_many(
|
||||
keys=uids,
|
||||
subkeys=[dht.peer_id.to_base58()] * len(uids),
|
||||
values=[(state.value, throughput)] * len(uids),
|
||||
expiration_time=expiration_time,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
|
||||
def get_remote_module_infos(
|
||||
dht: DHT,
|
||||
uids: Sequence[ModuleUID],
|
||||
expiration_time: Optional[DHTExpiration] = None,
|
||||
*,
|
||||
latest: bool = False,
|
||||
return_future: bool = False,
|
||||
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
|
||||
return dht.run_coroutine(
|
||||
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
|
||||
return_future=return_future,
|
||||
)
|
||||
|
||||
|
||||
async def _get_remote_module_infos(
|
||||
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
|
||||
) -> List[Optional[RemoteModuleInfo]]:
|
||||
if latest:
|
||||
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
|
||||
expiration_time = math.inf
|
||||
elif expiration_time is None:
|
||||
expiration_time = get_dht_time()
|
||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
|
||||
|
||||
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
|
||||
for i, uid in enumerate(uids):
|
||||
metadata = found[uid]
|
||||
if metadata is None or not isinstance(metadata.value, dict):
|
||||
if metadata is not None:
|
||||
logger.error(f"Incorrect metadata for {uid}: {metadata}")
|
||||
continue
|
||||
servers = {}
|
||||
for peer_id, server_info in metadata.value.items():
|
||||
try:
|
||||
peer_id = PeerID.from_base58(peer_id)
|
||||
state, throughput = server_info.value
|
||||
if not (
|
||||
isinstance(state, int)
|
||||
and isinstance(throughput, float)
|
||||
and math.isfinite(throughput)
|
||||
and throughput >= 0.0
|
||||
):
|
||||
raise ValueError(f"Invalid server info: {server_info}")
|
||||
servers[peer_id] = ServerInfo(ServerState(state), throughput)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
|
||||
if servers:
|
||||
modules[i] = RemoteModuleInfo(uid, servers)
|
||||
return modules
|
||||
from petals.utils.dht import *
|
||||
|
@ -0,0 +1,4 @@
|
||||
from petals.models.bloom import *
|
||||
from petals.models.falcon import *
|
||||
from petals.models.llama import *
|
||||
from petals.models.mixtral import *
|
@ -0,0 +1,15 @@
|
||||
from petals.models.bloom.block import WrappedBloomBlock
|
||||
from petals.models.bloom.config import DistributedBloomConfig
|
||||
from petals.models.bloom.model import (
|
||||
DistributedBloomForCausalLM,
|
||||
DistributedBloomForSequenceClassification,
|
||||
DistributedBloomModel,
|
||||
)
|
||||
from petals.utils.auto_config import register_model_classes
|
||||
|
||||
register_model_classes(
|
||||
config=DistributedBloomConfig,
|
||||
model=DistributedBloomModel,
|
||||
model_for_causal_lm=DistributedBloomForCausalLM,
|
||||
model_for_sequence_classification=DistributedBloomForSequenceClassification,
|
||||
)
|
@ -0,0 +1,45 @@
|
||||
"""
|
||||
Bloom intermediate layer
|
||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
||||
See commit history for authorship.
|
||||
"""
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
|
||||
|
||||
from petals.utils.misc import is_dummy
|
||||
|
||||
|
||||
class WrappedBloomBlock(BloomBlock):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
alibi: Optional[torch.Tensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs
|
||||
):
|
||||
assert attention_mask is None, "Non-causal attention masks are not supported yet"
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
if layer_past is not None and is_dummy(layer_past[0]):
|
||||
# Bloom cannot use the cache if it was misconstructed (e.g., dummy tensors).
# In this case, fall back to the old code:
|
||||
layer_past = None
|
||||
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
|
||||
seq_length_with_past = seq_length + past_length
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
if alibi is None:
|
||||
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask=attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
inputs_embeds=hidden_states,
|
||||
past_key_values_length=past_length,
|
||||
)
|
||||
attention_mask = attention_mask.bool()
|
||||
return super().forward(
|
||||
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
|
||||
)
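
# Note (illustrative, not part of the diff): the block always rebuilds an all-ones attention mask of shape
# [batch, seq_length + past_length] on the fly and converts it into a 4D causal mask; callers are expected
# to pass attention_mask=None (asserted above), since the causal masking happens here on the server side.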
|
@ -0,0 +1,35 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from hivemind import get_logger
|
||||
from transformers.models.bloom import BloomConfig
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
from petals.client.config import ClientConfig
|
||||
from petals.client.lm_head import LMHeadConfig
|
||||
from petals.client.ptune import PTuneConfig
|
||||
from petals.models.bloom.block import WrappedBloomBlock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
|
||||
block_class = WrappedBloomBlock
|
||||
attn_class = BloomAttention
|
||||
block_prefix = "h"
|
||||
|
||||
num_key_value_groups = 1
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||
):
|
||||
logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
|
||||
|
||||
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||
if loading_from_repo and dht_prefix is None:
|
||||
# We need "-petals" for backward compatibility with Petals < 1.2.0
|
||||
dht_prefix = str(model_name_or_path) + "-petals"
|
||||
dht_prefix = dht_prefix.replace(".", "-")
|
||||
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
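
# Illustrative example (not part of the diff): for a repo name such as "bigscience/bloom", the derived
# prefix becomes "bigscience/bloom-petals"; any "." in the name is replaced with "-".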
|
@ -0,0 +1,197 @@
|
||||
from typing import Optional
|
||||
|
||||
import hivemind
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
|
||||
|
||||
from petals.client.from_pretrained import FromPretrainedMixin
|
||||
from petals.client.lm_head import LMHead
|
||||
from petals.client.ptune import PTuneMixin
|
||||
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.models.bloom.config import DistributedBloomConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = [r"^h\."]
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||
super().__init__(config)
|
||||
assert len(self.h) == 0
|
||||
config.num_hidden_layers = n_layer
|
||||
|
||||
self.h = RemoteSequential(config, dht=dht)
|
||||
|
||||
self.requires_grad_(False)  # Forbid gradient accumulation for embeddings and layernorm
|
||||
self.init_prompts(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# The causal mask will be added on the server-side
|
||||
assert (
|
||||
attention_mask is None or (attention_mask == 1).all()
|
||||
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
||||
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
||||
if use_prompts:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
else:
|
||||
prompts = intermediate_prompts = None
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
hidden_states = self.h(
|
||||
hidden_states,
|
||||
prompts=intermediate_prompts,
|
||||
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||
)
|
||||
|
||||
# Remove prefix
|
||||
if use_prompts:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = RemotePastKeyValues()
|
||||
past_key_values.update_seen(hidden_states.size(1))
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
|
||||
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
|
||||
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||
_supports_cache_class = True
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.lm_head = LMHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
) -> dict:
|
||||
# Omit tokens covered by past_key_values
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
max_cache_length = past_key_values.get_max_length()
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
max_cache_length = None
|
||||
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _temporary_reorder_cache(self, past_key_values, beam_idx):
|
||||
return past_key_values
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
|
||||
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
|
||||
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
@ -0,0 +1,15 @@
|
||||
from petals.models.falcon.block import WrappedFalconBlock
|
||||
from petals.models.falcon.config import DistributedFalconConfig
|
||||
from petals.models.falcon.model import (
|
||||
DistributedFalconForCausalLM,
|
||||
DistributedFalconForSequenceClassification,
|
||||
DistributedFalconModel,
|
||||
)
|
||||
from petals.utils.auto_config import register_model_classes
|
||||
|
||||
register_model_classes(
|
||||
config=DistributedFalconConfig,
|
||||
model=DistributedFalconModel,
|
||||
model_for_causal_lm=DistributedFalconForCausalLM,
|
||||
model_for_sequence_classification=DistributedFalconForSequenceClassification,
|
||||
)
|
@ -0,0 +1,480 @@
|
||||
"""
|
||||
Falcon intermediate layer
|
||||
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
|
||||
See commit history for authorship.
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.falcon.modeling_falcon import (
|
||||
FalconAttention,
|
||||
FalconConfig,
|
||||
FalconDecoderLayer,
|
||||
FalconLinear,
|
||||
FalconMLP,
|
||||
FalconModel,
|
||||
LayerNorm,
|
||||
build_alibi_tensor,
|
||||
dropout_add,
|
||||
rotate_half,
|
||||
)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
INFERENCE_MAX_LENGTH = 8192
|
||||
|
||||
|
||||
def apply_rotary(query, key, cos, sin):
|
||||
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
||||
|
||||
|
||||
class OptimizedFalconRotaryEmbedding(nn.Module):
|
||||
def __init__(self, head_dim: int, base=10000):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.head_dim = head_dim
|
||||
self.seq_len_cached = -1
|
||||
|
||||
self.cuda_graph = None
|
||||
self.input_surface = None
|
||||
self.static_outputs = None
|
||||
|
||||
def _optimized_apply_rotary(self, query, key, cos, sin):
|
||||
if self.cuda_graph is None:
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
self.input_surface = (query, key, cos, sin)
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(3):
|
||||
apply_rotary(*self.input_surface)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
self.static_outputs = apply_rotary(*self.input_surface)
|
||||
|
||||
inputs = (query, key, cos, sin)
|
||||
for static_input, data in zip(self.input_surface, inputs):
|
||||
static_input.copy_(data)
|
||||
self.cuda_graph.replay()
|
||||
return tuple(o.detach() for o in self.static_outputs)
|
||||
|
||||
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
||||
total_length = seq_len + past_key_values_length
|
||||
if self.seq_len_cached == -1:
|
||||
# warm up the cache
|
||||
total_length = max(INFERENCE_MAX_LENGTH, total_length)
|
||||
|
||||
if total_length > self.seq_len_cached:
|
||||
with torch.inference_mode(False):
|
||||
self.seq_len_cached = total_length
|
||||
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
||||
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
emb = emb.float()
|
||||
|
||||
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
||||
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
||||
)
|
||||
|
||||
def forward(self, query, key, past_key_values_length=0):
|
||||
batch, seq_len, head_dim = query.shape
|
||||
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
|
||||
if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
|
||||
return self._optimized_apply_rotary(query, key, cos, sin)
|
||||
else:
|
||||
return apply_rotary(query, key, cos, sin)
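# A standalone restatement of the cached-table math in cos_sin() above, with a toy
# head_dim, to make the layout concrete (illustration only; no CUDA graphs involved):
import torch

head_dim, base, total_length = 8, 10000, 16
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # [head_dim // 2]
t = torch.arange(total_length, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)   # [total_length, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)        # [total_length, head_dim]
cos_cached = emb.cos()[None, :, :]             # [1, total_length, head_dim]
sin_cached = emb.sin()[None, :, :]
# forward() then slices positions [past, past + seq_len) from these tables and feeds them
# to apply_rotary(), so the tables are built once up to INFERENCE_MAX_LENGTH and reused.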
|
||||
|
||||
|
||||
def split_heads(
|
||||
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch, seq_len, _ = fused_qkv.shape
|
||||
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
|
||||
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
|
||||
key = torch.broadcast_to(key, query.shape)
|
||||
value = torch.broadcast_to(value, query.shape)
|
||||
|
||||
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
|
||||
return query, key, value
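# A shape sanity check for split_heads() above, with small illustrative head counts in the
# "new decoder architecture" layout (grouped KV heads); numbers are assumptions for the example:
import torch

batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 4
groups = num_heads // num_kv_heads  # 4 query heads share each KV head
fused_qkv = torch.randn(batch, seq_len, num_kv_heads * (groups + 2) * head_dim)

q, k, v = split_heads(fused_qkv, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim)
assert q.shape == k.shape == v.shape == (batch, seq_len, num_heads, head_dim)
# Keys and values are broadcast from one head per group to match the query head count.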
|
||||
|
||||
|
||||
class OptimizedFalconAttention(FalconAttention):
|
||||
def __init__(self, config: FalconConfig):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.split_size = self.hidden_size
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
if self.head_dim * self.num_heads != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = self.inv_norm_factor
|
||||
if config.new_decoder_architecture:
|
||||
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
|
||||
elif config.multi_query:
|
||||
qkv_out_dim = self.hidden_size + 2 * self.head_dim
|
||||
else:
|
||||
qkv_out_dim = 3 * self.hidden_size
|
||||
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
|
||||
self.new_decoder_architecture = config.new_decoder_architecture
|
||||
self.multi_query = config.multi_query
|
||||
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
||||
|
||||
if self.new_decoder_architecture:
|
||||
self._split_heads = partial(
|
||||
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
|
||||
)
|
||||
self.split_graph = None
|
||||
self.input_surface = None
|
||||
self.static_outputs = None
|
||||
|
||||
def _optimized_split_heads(self, fused_qkv):
|
||||
if self.split_graph is None:
|
||||
self.split_graph = torch.cuda.CUDAGraph()
|
||||
self.input_surface = fused_qkv
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(3):
|
||||
self._split_heads(fused_qkv)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(self.split_graph):
|
||||
self.static_outputs = self._split_heads(self.input_surface)
|
||||
|
||||
self.input_surface.copy_(fused_qkv)
|
||||
self.split_graph.replay()
|
||||
return tuple(o.detach() for o in self.static_outputs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
assert not output_attentions
|
||||
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
|
||||
if (
|
||||
self.new_decoder_architecture
|
||||
and hidden_states.size(1) == 1
|
||||
and torch.is_inference_mode_enabled()
|
||||
and hidden_states.device.type == "cuda"
|
||||
):
|
||||
query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
|
||||
else:
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
num_kv_heads = self.num_heads
|
||||
batch_size, query_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads,
|
||||
query_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
_, kv_length, _ = key_layer.shape
|
||||
if use_cache:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
|
||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
|
||||
if alibi is None:
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
||||
|
||||
output_tensor = self.dense(attn_output)
|
||||
|
||||
return output_tensor, present
|
||||
else:
|
||||
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
|
||||
attention_scores = attention_scores.to(torch.float32)
|
||||
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
|
||||
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
|
||||
# equivalent and more performant, but there might be a numerical difference. If you're reading this
|
||||
# and you'd like to experiment and maybe file a PR, feel free!
|
||||
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
|
||||
attention_logits *= self.inv_norm_factor
|
||||
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
# change view [batch_size, num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
|
||||
|
||||
# change view [batch_size, q_length, num_heads * head_dim]
|
||||
context_layer = self._merge_heads(context_layer)
|
||||
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
if output_attentions:
|
||||
return output_tensor, present, attention_probs
|
||||
else:
|
||||
return output_tensor, present
|
||||
|
||||
|
||||
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
|
||||
def __init__(self, config: FalconConfig):
|
||||
nn.Module.__init__(self)
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.mlp = FalconMLP(config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
self.config = config
|
||||
|
||||
self.self_attention = OptimizedFalconAttention(config)
|
||||
|
||||
if self.config.alibi or not config.new_decoder_architecture:
|
||||
if config.new_decoder_architecture:
|
||||
# The layer norm before self-attention
|
||||
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
# The layer norm before the MLP
|
||||
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
else:
|
||||
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
if not config.parallel_attn:
|
||||
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
else:
|
||||
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.ln_graph = None
|
||||
self.static_input = None
|
||||
self.static_outputs = None
|
||||
|
||||
def _optimized_apply_ln(self, hidden_states):
|
||||
if self.ln_graph is None:
|
||||
self.ln_graph = torch.cuda.CUDAGraph()
|
||||
self.static_input = hidden_states
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(3):
|
||||
self.ln_attn(hidden_states)
|
||||
self.ln_mlp(hidden_states)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(self.ln_graph):
|
||||
ln_attn_output = self.ln_attn(hidden_states)
|
||||
ln_mlp_output = self.ln_mlp(hidden_states)
|
||||
self.static_outputs = (ln_attn_output, ln_mlp_output)
|
||||
|
||||
self.static_input.copy_(hidden_states)
|
||||
self.ln_graph.replay()
|
||||
return tuple(o.detach() for o in self.static_outputs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
||||
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
|
||||
else:
|
||||
attention_layernorm_out = self.ln_attn(hidden_states)
|
||||
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
||||
else:
|
||||
attention_layernorm_out = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_outputs = self.self_attention(
|
||||
attention_layernorm_out,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
alibi=alibi,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attention_output = attn_outputs[0]
|
||||
|
||||
if not self.config.new_decoder_architecture:
|
||||
if self.config.parallel_attn:
|
||||
mlp_layernorm_out = attention_layernorm_out
|
||||
else:
|
||||
residual = dropout_add(
|
||||
attention_output, residual, self.config.attention_dropout, training=self.training
|
||||
)
|
||||
mlp_layernorm_out = self.post_attention_layernorm(residual)
|
||||
|
||||
outputs = attn_outputs[1:]
|
||||
|
||||
mlp_output = self.mlp(mlp_layernorm_out)
|
||||
|
||||
if self.config.new_decoder_architecture or self.config.parallel_attn:
|
||||
mlp_output += attention_output
|
||||
|
||||
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
|
||||
|
||||
if use_cache:
|
||||
outputs = (output,) + outputs
|
||||
else:
|
||||
outputs = (output,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, attentions
|
||||
|
||||
|
||||
class WrappedFalconBlock(OptimizedFalconDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
alibi: Optional[torch.Tensor] = None,
|
||||
layer_past: Optional[KVCache] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert attention_mask is None
|
||||
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
if layer_past is not None:
|
||||
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
|
||||
past_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
seq_length_with_past = seq_length + past_length
|
||||
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
if alibi is None and self.config.alibi:
|
||||
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
|
||||
|
||||
outputs = super().forward(
|
||||
hidden_states,
|
||||
*args,
|
||||
attention_mask=attention_mask,
|
||||
alibi=alibi,
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
present_key_value = outputs[-1]
|
||||
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
|
||||
outputs = outputs[:-1] + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
|
||||
key_states, value_states = key_value
|
||||
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
key_states = self._expand_states(key_states)
|
||||
value_states = self._expand_states(value_states)
|
||||
|
||||
return (key_states, value_states)
|
||||
|
||||
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
|
||||
key_states, value_states = key_value
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
key_states = self._collapse_states(key_states)
|
||||
value_states = self._collapse_states(value_states)
|
||||
|
||||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
|
||||
return (key_states, value_states)
|
||||
|
||||
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
|
||||
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
|
||||
|
||||
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
|
||||
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
|
||||
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
|
||||
return state
|
||||
|
||||
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
|
||||
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
|
||||
|
||||
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
|
||||
state = state[:, :, 0]
|
||||
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
|
||||
return state
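# An illustrative round-trip check for the two helpers above, restated standalone with
# hypothetical sizes: collapsing an expanded KV state recovers the original tensor, since
# every KV head is repeated num_key_value_groups times and the collapse keeps one copy per group.
import torch

batch, num_kv_heads, groups, seq_len, head_dim = 2, 4, 3, 6, 8
num_attention_heads = num_kv_heads * groups
state = torch.randn(batch * num_kv_heads, seq_len, head_dim)

expanded = (
    state.view(batch, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, groups, -1, -1)
    .reshape(batch * num_attention_heads, seq_len, head_dim)
)
collapsed = (
    expanded.view(batch, num_kv_heads, groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch * num_kv_heads, seq_len, head_dim)
)
assert torch.equal(collapsed, state)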
|
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from hivemind import get_logger
|
||||
from transformers.models.falcon import FalconConfig
|
||||
from transformers.models.falcon.modeling_falcon import FalconAttention
|
||||
|
||||
from petals.client.config import ClientConfig
|
||||
from petals.client.lm_head import LMHeadConfig
|
||||
from petals.client.ptune import PTuneConfig
|
||||
from petals.models.falcon.block import WrappedFalconBlock
|
||||
from petals.utils.auto_config import DefaultRevisionMixin
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
|
||||
block_class = WrappedFalconBlock
|
||||
attn_class = FalconAttention
|
||||
block_prefix = "transformer.h"
|
||||
|
||||
@property
|
||||
def num_key_value_groups(self) -> int:
|
||||
if self.new_decoder_architecture:
|
||||
return self.num_attention_heads // self.num_kv_heads
|
||||
if self.multi_query:
|
||||
return self.num_attention_heads
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||
):
|
||||
if "180B" in model_name_or_path.upper():
|
||||
logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
|
||||
|
||||
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||
if loading_from_repo and dht_prefix is None:
|
||||
dht_prefix = str(model_name_or_path)
|
||||
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
|
||||
dht_prefix = dht_prefix.replace(".", "-")
|
||||
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||
|
||||
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||
config = result[0] if isinstance(result, tuple) else result
|
||||
if config.pad_token_id is None:
|
||||
config.pad_token_id = 0
|
||||
return result
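# A worked example for the num_key_value_groups property above; the head counts are
# illustrative (roughly Falcon-40B-like) rather than read from this diff:
from transformers import FalconConfig

cfg = FalconConfig(num_attention_heads=128, num_kv_heads=8, new_decoder_architecture=True)
print(cfg.num_attention_heads // cfg.num_kv_heads)  # 16 query heads share each KV head
# With multi_query=True (Falcon-7B style) the property returns num_attention_heads instead,
# and with neither flag set it falls back to 1 (classic multi-head attention).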
|
@@ -0,0 +1,154 @@
|
||||
from typing import Optional
|
||||
|
||||
import hivemind
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.falcon import (
|
||||
FalconForCausalLM,
|
||||
FalconForSequenceClassification,
|
||||
FalconModel,
|
||||
FalconPreTrainedModel,
|
||||
)
|
||||
|
||||
from petals.client.from_pretrained import FromPretrainedMixin
|
||||
from petals.client.lm_head import LMHead
|
||||
from petals.client.ptune import PTuneMixin
|
||||
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.models.falcon.config import DistributedFalconConfig
|
||||
from petals.utils.auto_config import DefaultRevisionMixin
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
|
||||
"""FalconModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
|
||||
|
||||
config_class = DistributedFalconConfig
|
||||
|
||||
def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||
super().__init__(config)
|
||||
assert len(self.h) == 0
|
||||
config.num_hidden_layers = n_layer
|
||||
|
||||
self.h = RemoteSequential(config, dht=dht)
|
||||
|
||||
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||
self.init_prompts(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# The causal mask will be added on the server-side
|
||||
assert (
|
||||
attention_mask is None or (attention_mask == 1).all()
|
||||
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||
assert (
|
||||
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
|
||||
), f"Non-consecutive position_ids are not supported, {position_ids=}"
|
||||
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
||||
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
||||
if use_prompts:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
else:
|
||||
prompts = intermediate_prompts = None
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
hidden_states = self.h(
|
||||
hidden_states,
|
||||
prompts=intermediate_prompts,
|
||||
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||
)
|
||||
|
||||
# Remove prefix
|
||||
if use_prompts:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=RemotePastKeyValues(),
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
|
||||
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedFalconConfig
|
||||
|
||||
def __init__(self, config: DistributedFalconConfig):
|
||||
FalconPreTrainedModel.__init__(self, config)
|
||||
self.transformer = DistributedFalconModel(config)
|
||||
self.lm_head = LMHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
|
||||
class DistributedFalconForSequenceClassification(
|
||||
DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
|
||||
):
|
||||
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedFalconConfig
|
||||
|
||||
def __init__(self, config: DistributedFalconConfig):
|
||||
FalconPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = DistributedFalconModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
@@ -0,0 +1,15 @@
|
||||
from petals.models.llama.block import WrappedLlamaBlock
|
||||
from petals.models.llama.config import DistributedLlamaConfig
|
||||
from petals.models.llama.model import (
|
||||
DistributedLlamaForCausalLM,
|
||||
DistributedLlamaForSequenceClassification,
|
||||
DistributedLlamaModel,
|
||||
)
|
||||
from petals.utils.auto_config import register_model_classes
|
||||
|
||||
register_model_classes(
|
||||
config=DistributedLlamaConfig,
|
||||
model=DistributedLlamaModel,
|
||||
model_for_causal_lm=DistributedLlamaForCausalLM,
|
||||
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
|
||||
)
|
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
LLaMA intermediate layer
|
||||
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
See commit history for authorship.
|
||||
"""
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaConfig,
|
||||
LlamaDecoderLayer,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
repeat_kv,
|
||||
rotate_half,
|
||||
)
|
||||
|
||||
from petals.utils.cuda_graphs import make_inference_graphed_callable
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin):
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class OptimizedLlamaAttention(LlamaAttention):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._rotary_graph = None
|
||||
|
||||
def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
|
||||
if self._rotary_graph is None:
|
||||
self._rotary_graph = make_inference_graphed_callable(
|
||||
apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
|
||||
)
|
||||
return self._rotary_graph(query_states, key_states, cos, sin)
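# For orientation, a minimal sketch of the capture-and-replay pattern that a helper like
# make_inference_graphed_callable is expected to implement. This is an assumption-labeled
# illustration, not the actual utility from petals.utils.cuda_graphs; it mirrors the
# hand-written CUDA-graph code in the Falcon block earlier in this diff.
import torch

def make_graphed_callable_sketch(fn, sample_args):
    """Capture `fn` once over static buffers, then serve calls by copy-in + replay.

    Assumes `fn` returns a tuple of tensors and is shape-stable across calls.
    """
    static_args = tuple(arg.clone() for arg in sample_args)

    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _ in range(3):  # warm up kernels on a side stream before capture
            fn(*static_args)
    torch.cuda.current_stream().wait_stream(warmup_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = fn(*static_args)

    def replay(*args):
        for static_arg, arg in zip(static_args, args):
            static_arg.copy_(arg)   # copy fresh inputs into the captured buffers
        graph.replay()
        return tuple(output.detach() for output in static_outputs)

    return replay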
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
assert not output_attentions
|
||||
if position_ids is None:
|
||||
past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
|
||||
position_ids = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
|
||||
).unsqueeze(0)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
||||
cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
||||
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
|
||||
else:
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = OptimizedLlamaAttention(config=config)
|
||||
self.mlp = LlamaMLP(config)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pre_attn_graph = None
|
||||
self.post_attn_graph = None
|
||||
|
||||
def _optimized_input_layernorm(self, hidden_states):
|
||||
if self.pre_attn_graph is None:
|
||||
self.pre_attn_graph = make_inference_graphed_callable(
|
||||
self.input_layernorm.forward, sample_args=(hidden_states,)
|
||||
)
|
||||
return self.pre_attn_graph(hidden_states)
|
||||
|
||||
def _optimized_output_layernorm(self, hidden_states):
|
||||
if self.post_attn_graph is None:
|
||||
self.post_attn_graph = make_inference_graphed_callable(
|
||||
self.post_attention_layernorm.forward, sample_args=(hidden_states,)
|
||||
)
|
||||
return self.post_attn_graph(hidden_states)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
||||
hidden_states = self._optimized_input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
|
||||
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
||||
hidden_states = self._optimized_output_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
past_key_value = layer_past
|
||||
if past_key_value is not None:
|
||||
past_key_values_length = past_key_value[0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
|
||||
|
||||
assert position_ids is None
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask=attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
inputs_embeds=hidden_states,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
outputs = super().forward(
|
||||
hidden_states,
|
||||
*args,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
present_key_value = outputs[-1]
|
||||
present_key_value = self._reorder_cache_from_llama_to_bloom(
|
||||
present_key_value, batch_size, seq_length_with_past
|
||||
)
|
||||
outputs = outputs[:-1] + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _reorder_cache_from_bloom_to_llama(
|
||||
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||
) -> Tuple[torch.Tensor]:
|
||||
key_states, value_states = key_value
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
key_states = key_states.view(
|
||||
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
value_states = value_states.view(*key_states.shape)
|
||||
return (key_states, value_states)
|
||||
|
||||
def _reorder_cache_from_llama_to_bloom(
|
||||
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||
) -> Tuple[torch.Tensor]:
|
||||
key_states, value_states = key_value
|
||||
value_states = value_states.view(
|
||||
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
key_states = key_states.view(*value_states.shape)
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
return (key_states, value_states)
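# A shape round-trip illustration for the two reorder helpers above, restated standalone
# with hypothetical sizes. Bloom-style caches keep keys as [batch * num_kv_heads, head_dim,
# seq_len] and values as [batch * num_kv_heads, seq_len, head_dim]; LLaMA-style caches keep
# both as [batch, num_kv_heads, seq_len, head_dim].
import torch

batch, num_kv_heads, seq_len, head_dim = 2, 4, 6, 8
bloom_key = torch.randn(batch * num_kv_heads, head_dim, seq_len)
bloom_value = torch.randn(batch * num_kv_heads, seq_len, head_dim)

# Bloom -> LLaMA, as in _reorder_cache_from_bloom_to_llama()
llama_key = bloom_key.permute(0, 2, 1).reshape(batch, num_kv_heads, seq_len, head_dim)
llama_value = bloom_value.view(batch, num_kv_heads, seq_len, head_dim)

# LLaMA -> Bloom, as in _reorder_cache_from_llama_to_bloom()
back_key = llama_key.reshape(batch * num_kv_heads, seq_len, head_dim).permute(0, 2, 1)
back_value = llama_value.view(batch * num_kv_heads, seq_len, head_dim)

assert torch.equal(back_key, bloom_key) and torch.equal(back_value, bloom_value)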
|
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from hivemind import get_logger
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
from petals.client.config import ClientConfig
|
||||
from petals.client.lm_head import LMHeadConfig
|
||||
from petals.client.ptune import PTuneConfig
|
||||
from petals.models.llama.block import WrappedLlamaBlock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
|
||||
block_class = WrappedLlamaBlock
|
||||
attn_class = LlamaAttention
|
||||
block_prefix = "model.layers"
|
||||
|
||||
@property
|
||||
def num_key_value_groups(self):
|
||||
return self.num_attention_heads // self.num_key_value_heads
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||
):
|
||||
logger.info(
|
||||
"Make sure you follow the LLaMA's terms of use: "
|
||||
"https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
|
||||
)
|
||||
|
||||
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||
if loading_from_repo and dht_prefix is None:
|
||||
dht_prefix = str(model_name_or_path)
|
||||
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
|
||||
dht_prefix = dht_prefix.replace(".", "-")
|
||||
if not dht_prefix.endswith("-hf"):
|
||||
dht_prefix += "-hf"
|
||||
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||
|
||||
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||
config = result[0] if isinstance(result, tuple) else result
|
||||
config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization
|
||||
config.use_cache = True # use_cache=False leads to identical results but is slower and not supported by Petals
|
||||
return result
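# A standalone restatement of the DHT prefix derivation above, with illustrative repo names
# (not read from the diff). Keeping only the repo name lets servers hosting the same weights
# under different accounts join one swarm namespace.
def derive_dht_prefix(model_name_or_path: str) -> str:
    prefix = str(model_name_or_path).split("/")[-1]  # drop the account, keep the repo name
    prefix = prefix.replace(".", "-")
    if not prefix.endswith("-hf"):
        prefix += "-hf"
    return prefix

assert derive_dht_prefix("meta-llama/Llama-2-70b-chat-hf") == "Llama-2-70b-chat-hf"
assert derive_dht_prefix("huggyllama/llama-65b") == "llama-65b-hf"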
|
@@ -0,0 +1,174 @@
|
||||
from typing import Optional
|
||||
|
||||
import hivemind
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
|
||||
|
||||
from petals.client.from_pretrained import FromPretrainedMixin
|
||||
from petals.client.lm_head import LMHead
|
||||
from petals.client.ptune import PTuneMixin
|
||||
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.models.llama.config import DistributedLlamaConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
|
||||
"""LlamaModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
|
||||
|
||||
config_class = DistributedLlamaConfig
|
||||
|
||||
def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||
super().__init__(config)
|
||||
assert len(self.layers) == 0
|
||||
config.num_hidden_layers = n_layer
|
||||
|
||||
self.layers = RemoteSequential(config, dht=dht)
|
||||
|
||||
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||
self.init_prompts(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> BaseModelOutputWithPast:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# The causal mask will be added on the server-side
|
||||
assert (
|
||||
attention_mask is None or (attention_mask == 1).all()
|
||||
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||
if cache_position is not None:
|
||||
assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
|
||||
assert (
|
||||
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
|
||||
), f"Non-consecutive position_ids are not supported, {position_ids=}"
|
||||
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
|
||||
if use_prompts:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
else:
|
||||
prompts = intermediate_prompts = None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
hidden_states = self.layers(
|
||||
hidden_states,
|
||||
prompts=intermediate_prompts,
|
||||
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||
)
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = RemotePastKeyValues()
|
||||
past_key_values.update_seen(hidden_states.size(1))
|
||||
|
||||
# Remove prefix
|
||||
if use_prompts:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
|
||||
return self.embed_tokens
|
||||
|
||||
@property
|
||||
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||
return nn.Identity()
|
||||
|
||||
@property
|
||||
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
|
||||
return self.layers
|
||||
|
||||
@property
|
||||
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||
return self.norm
|
||||
|
||||
|
||||
class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
|
||||
_keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedLlamaConfig
|
||||
|
||||
def __init__(self, config: DistributedLlamaConfig):
|
||||
LlamaPreTrainedModel.__init__(self, config)
|
||||
self.model = DistributedLlamaModel(config)
|
||||
self.pretraining_tp = config.pretraining_tp
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = LMHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@property
|
||||
def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin
|
||||
return self.model
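# A hedged end-to-end sketch for the class above. The checkpoint name and tuning arguments
# are illustrative assumptions following the tuning_mode="ptune" / pre_seq_len pattern used
# by the PTune mixin, not something stated in this diff:
import torch
from transformers import AutoTokenizer
from petals.models.llama.model import DistributedLlamaForCausalLM

model_name = "meta-llama/Llama-2-70b-hf"  # illustrative repo name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedLlamaForCausalLM.from_pretrained(model_name, tuning_mode="ptune", pre_seq_len=16)

prompt = tokenizer("A public swarm can serve", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    output = model.generate(prompt, max_new_tokens=3)
print(tokenizer.decode(output[0]))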
|
||||
|
||||
|
||||
class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
|
||||
_keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedLlamaConfig
|
||||
|
||||
def __init__(self, config):
|
||||
LlamaPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = DistributedLlamaModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@property
|
||||
def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin
|
||||
return self.model
|
@@ -0,0 +1,15 @@
|
||||
from petals.models.mixtral.block import WrappedMixtralBlock
|
||||
from petals.models.mixtral.config import DistributedMixtralConfig
|
||||
from petals.models.mixtral.model import (
|
||||
DistributedMixtralForCausalLM,
|
||||
DistributedMixtralForSequenceClassification,
|
||||
DistributedMixtralModel,
|
||||
)
|
||||
from petals.utils.auto_config import register_model_classes
|
||||
|
||||
register_model_classes(
|
||||
config=DistributedMixtralConfig,
|
||||
model=DistributedMixtralModel,
|
||||
model_for_causal_lm=DistributedMixtralForCausalLM,
|
||||
model_for_sequence_classification=DistributedMixtralForSequenceClassification,
|
||||
)
|
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import MixtralConfig
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
|
||||
|
||||
|
||||
class WrappedMixtralBlock(MixtralDecoderLayer):
|
||||
def __init__(self, config: MixtralConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.sliding_window = config.sliding_window
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
past_key_value = layer_past
|
||||
|
||||
if past_key_value is not None:
|
||||
past_key_values_length = past_key_value[0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
|
||||
past_key_value = DynamicCache()
|
||||
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
|
||||
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
|
||||
past_key_value._seen_tokens = past_key_values_length
|
||||
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._attn_implementation == "sdpa":
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
|
||||
outputs = super().forward(
|
||||
hidden_states,
|
||||
*args,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
present_key_value = outputs[-1]
|
||||
present_key_value = present_key_value[self.layer_idx]
|
||||
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
|
||||
outputs = outputs[:-1] + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _reorder_cache_from_bloom(
|
||||
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||
) -> Tuple[torch.Tensor]:
|
||||
# TODO: Move to mixin
|
||||
key_states, value_states = key_value
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
key_states = key_states.view(
|
||||
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
value_states = value_states.view(*key_states.shape)
|
||||
return (key_states, value_states)
|
||||
|
||||
def _reorder_cache_to_bloom(
|
||||
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||
) -> Tuple[torch.Tensor]:
|
||||
# TODO: Move to mixin
|
||||
key_states, value_states = key_value
|
||||
value_states = value_states.view(
|
||||
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
key_states = key_states.view(*value_states.shape)
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
return (key_states, value_states)
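# An illustration of the cache bridging done in forward() above, restated standalone with
# hypothetical shapes: Petals passes per-block caches as legacy (key, value) tuples, while
# the recent MixtralDecoderLayer expects a Cache object indexed by layer_idx, so earlier
# slots are padded with placeholders.
import torch
from transformers.cache_utils import DynamicCache

layer_idx, batch, num_kv_heads, past_len, head_dim = 2, 1, 8, 5, 128
key = torch.randn(batch, num_kv_heads, past_len, head_dim)
value = torch.randn(batch, num_kv_heads, past_len, head_dim)

cache = DynamicCache()
cache.key_cache = [torch.empty(0) for _ in range(layer_idx)] + [key]
cache.value_cache = [torch.empty(0) for _ in range(layer_idx)] + [value]
cache._seen_tokens = past_len  # mirrors the bookkeeping in the wrapped block above

assert torch.equal(cache.key_cache[layer_idx], key)  # the decoder layer reads slot layer_idx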
|
@@ -0,0 +1,36 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.mixtral.block import WrappedMixtralBlock

logger = get_logger(__name__)


class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedMixtralBlock
    attn_class = MixtralAttention
    block_prefix = "model.layers"

    num_key_value_groups = 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")
        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
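A minimal client-side sketch of constructing this config; the repo id below is illustrative, and the import path is assumed from the package layout shown in this diff.

from petals.models.mixtral.config import DistributedMixtralConfig

config = DistributedMixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")  # hypothetical repo id
print(config.block_class.__name__, config.block_prefix)  # "WrappedMixtralBlock", "model.layers"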
@ -0,0 +1,178 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind import DHT
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.models.mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForSequenceClassification,
|
||||
MixtralModel,
|
||||
MixtralPreTrainedModel,
|
||||
)
|
||||
|
||||
from petals.client.from_pretrained import FromPretrainedMixin
|
||||
from petals.client.lm_head import LMHead
|
||||
from petals.client.ptune import PTuneMixin
|
||||
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.models.mixtral.config import DistributedMixtralConfig
|
||||
from petals.utils.auto_config import DefaultRevisionMixin
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
|
||||
"""MixtralModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
|
||||
|
||||
config_class = DistributedMixtralConfig
|
||||
|
||||
def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
|
||||
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||
super().__init__(config)
|
||||
assert len(self.layers) == 0
|
||||
config.num_hidden_layers = n_layer
|
||||
|
||||
self.layers = RemoteSequential(config, dht=dht)
|
||||
|
||||
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||
self.init_prompts(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# The causal mask will be added on the server-side
|
||||
assert (
|
||||
attention_mask is None or (attention_mask == 1).all()
|
||||
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||
assert (
|
||||
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
|
||||
), f"Non-consecutive position_ids are not supported, {position_ids=}"
|
||||
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
||||
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||
assert not output_router_logits, f"{output_router_logits=} is not supported"
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
||||
if use_prompts:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
else:
|
||||
prompts = intermediate_prompts = None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = RemotePastKeyValues()
|
||||
past_key_values.update_seen(hidden_states.size(1))
|
||||
|
||||
hidden_states = self.layers(
|
||||
hidden_states,
|
||||
prompts=intermediate_prompts,
|
||||
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||
)
|
||||
|
||||
# Remove prefix
|
||||
if use_prompts:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
|
||||
return self.embed_tokens
|
||||
|
||||
@property
|
||||
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
|
||||
return nn.Identity()
|
||||
|
||||
@property
|
||||
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
|
||||
return self.layers
|
||||
|
||||
@property
|
||||
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
|
||||
return self.norm
|
||||
|
||||
|
||||
class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
|
||||
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedMixtralConfig
|
||||
|
||||
def __init__(self, config: DistributedMixtralConfig):
|
||||
MixtralPreTrainedModel.__init__(self, config)
|
||||
self.model = DistributedMixtralModel(config)
|
||||
self.lm_head = LMHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@property
|
||||
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
|
||||
return self.model
|
||||
|
||||
|
||||
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
|
||||
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
||||
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
||||
|
||||
config_class = DistributedMixtralConfig
|
||||
|
||||
def __init__(self, config: DistributedMixtralConfig):
|
||||
MixtralPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = DistributedMixtralModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@property
|
||||
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
|
||||
return self.model
|
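To show how these distributed classes are meant to be used from the client side, here is a hedged sketch of loading the causal-LM wrapper and generating a few tokens. The repo id is illustrative, and it assumes a swarm is already serving blocks for that model.

import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)  # resolves to DistributedMixtralForCausalLM

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))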
@ -0,0 +1,230 @@
|
||||
"""
|
||||
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
|
||||
from hivemind.moe.expert_uid import ExpertUID
|
||||
from hivemind.proto import runtime_pb2
|
||||
from hivemind.utils.logging import get_logger
|
||||
from hivemind.utils.nested import nested_flatten
|
||||
|
||||
from petals.data_structures import Handle, InferenceMetadata
|
||||
from petals.server.backend import TransformerBackend
|
||||
from petals.server.task_pool import PrioritizedTaskPool
|
||||
from petals.server.task_prioritizer import TaskPrioritizerBase
|
||||
from petals.utils.convert_block import QuantType
|
||||
from petals.utils.misc import DUMMY, is_dummy
|
||||
from petals.utils.packaging import unpack_args_kwargs
|
||||
|
||||
# We prioritize short inference requests and make them use a *merged* inference pool,
|
||||
# so they are processed without interruptions and extra overheads
|
||||
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
|
||||
MAX_SHORT_INFERENCE_TOKENS = 128
|
||||
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def run_rpc_forward(
|
||||
*flat_tensors: torch.Tensor,
|
||||
requested_backends: Sequence[TransformerBackend],
|
||||
active_adapter: str = "",
|
||||
prioritizer: TaskPrioritizerBase,
|
||||
points: int = 0,
|
||||
args_structure: Any = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
||||
|
||||
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
|
||||
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
|
||||
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
|
||||
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
|
||||
"""
|
||||
if args_structure is not None:
|
||||
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
|
||||
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
|
||||
hidden_states, prompts, *_ = flat_tensors
|
||||
|
||||
dtype = requested_backends[0].dtype
|
||||
# parse input tensors and cast dtypes
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
assert hidden_states.ndim == 3
|
||||
if prompts is None or is_dummy(prompts):
|
||||
prompts = [DUMMY] * len(requested_backends)
|
||||
else:
|
||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
||||
|
||||
# Run a chain of requested backends
|
||||
for backend, prompt in zip(requested_backends, prompts):
|
||||
if not is_dummy(prompt):
|
||||
hidden_states[:, : prompt.shape[1]] += prompt
|
||||
|
||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
|
||||
priority = prioritizer.prioritize(
|
||||
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
|
||||
)
|
||||
(hidden_states,) = await backend.forward_pool.submit_task(
|
||||
hidden_states,
|
||||
active_adapter,
|
||||
priority=priority,
|
||||
)
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
assert (
|
||||
hidden_states.ndim == 3
|
||||
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
async def run_rpc_backward(
|
||||
*flat_tensors: torch.Tensor,
|
||||
requested_backends: Sequence[TransformerBackend],
|
||||
active_adapter: str = "",
|
||||
prioritizer: TaskPrioritizerBase,
|
||||
points: int = 0,
|
||||
args_structure: Any = None,
|
||||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
||||
if args_structure is not None:
|
||||
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
|
||||
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
|
||||
inputs, grad_outputs, prompts, *_ = flat_tensors
|
||||
|
||||
# Cast inputs & grad outputs to backend dtype
|
||||
inputs = inputs.to(requested_backends[0].dtype)
|
||||
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
|
||||
|
||||
if prompts is None or is_dummy(prompts):
|
||||
prompts = [DUMMY] * len(requested_backends)
|
||||
else:
|
||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
||||
|
||||
# Run a forward chain to collect intermediate inputs
|
||||
# Note that we do not forward for the last module since we do not need its output
|
||||
inter_inputs = []
|
||||
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
|
||||
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
|
||||
if not is_dummy(prompt):
|
||||
inputs[:, : prompt.shape[1]] += prompt
|
||||
inter_inputs.append(inputs)
|
||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
|
||||
priority = prioritizer.prioritize(
|
||||
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
|
||||
)
|
||||
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
|
||||
|
||||
assert isinstance(inputs, torch.Tensor)
|
||||
|
||||
if not is_dummy(prompts[-1]):
|
||||
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
|
||||
inter_inputs.append(inputs)
|
||||
|
||||
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
|
||||
grad_prompts_reversed = []
|
||||
# Run a chain of requested backends
|
||||
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
|
||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
|
||||
priority = prioritizer.prioritize(
|
||||
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
|
||||
)
|
||||
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
|
||||
|
||||
assert isinstance(grad_outputs, torch.Tensor)
|
||||
if not is_dummy(prompt):
|
||||
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|
||||
|
||||
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
|
||||
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
|
||||
|
||||
|
||||
async def iterate_rpc_inference(
|
||||
requested_uids: Sequence[ExpertUID],
|
||||
requested_backends: Sequence[TransformerBackend],
|
||||
active_adapter: Optional[str],
|
||||
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
|
||||
cache_handles: Sequence[Sequence[Handle]],
|
||||
*,
|
||||
max_length: int,
|
||||
prioritizer: TaskPrioritizerBase,
|
||||
points: int,
|
||||
quant_type: QuantType,
|
||||
args_structure: Any = None,
|
||||
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
|
||||
assert len(cache_handles) == len(requested_backends)
|
||||
|
||||
prefix_length = 0
|
||||
point_per_piece = points / max_length if max_length > 0 else 0.0
|
||||
|
||||
async for request, step_metadata in input_iterator:
|
||||
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
|
||||
if args_structure is not None:
|
||||
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
|
||||
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
|
||||
|
||||
hidden_states, prompts, hypo_ids, *_ = flat_tensors
|
||||
batch_size, length_increment, _ = hidden_states.shape
|
||||
|
||||
# Cast inputs to backend dtype
|
||||
hidden_states = hidden_states.to(requested_backends[0].dtype)
|
||||
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
|
||||
|
||||
# parse deep prompts (optional argument)
|
||||
has_prompts = prompts is not None and not is_dummy(prompts)
|
||||
if not has_prompts:
|
||||
prompts = [None] * len(requested_backends)
|
||||
else:
|
||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
||||
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
|
||||
|
||||
if not (len(requested_backends) == len(prompts)):
|
||||
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
|
||||
|
||||
if prefix_length + length_increment > max_length:
|
||||
raise ValueError(
|
||||
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
|
||||
f" exceeds pre-allocated maximum {max_length}"
|
||||
)
|
||||
|
||||
merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
|
||||
can_merge_pools = batch_size * length_increment <= merge_max_tokens
|
||||
priority = prioritizer.prioritize(
|
||||
hidden_states,
|
||||
hypo_ids,
|
||||
points=point_per_piece,
|
||||
requested_uids=requested_uids,
|
||||
type="inference",
|
||||
)
|
||||
|
||||
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
|
||||
# when user wants to pre-allocate cache or check that server *can* allocate that cache.
|
||||
if hidden_states.numel() > 0:
|
||||
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
|
||||
if can_merge_pools:
|
||||
inference_infos = tuple(
|
||||
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
|
||||
for uid, handles in zip(requested_uids, cache_handles)
|
||||
)
|
||||
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
|
||||
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
|
||||
)
|
||||
else:
|
||||
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
|
||||
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
|
||||
(hidden_states,) = await backend.inference_pool.submit_task(
|
||||
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
|
||||
)
|
||||
|
||||
# serialize and send last layer outputs
|
||||
output_tensors = [
|
||||
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
|
||||
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
|
||||
]
|
||||
can_push = not has_prompts
|
||||
yield output_tensors, can_push, step_metadata
|
||||
|
||||
# prepare for next step
|
||||
prefix_length += length_increment
|
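The short-request heuristic used by `iterate_rpc_inference` above reduces to a simple token-count threshold. A standalone restatement follows (constants copied from this file, scenario numbers hypothetical).

MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1

def can_merge_pools(batch_size: int, length_increment: int, is_nf4: bool) -> bool:
    limit = MAX_NF4_SHORT_INFERENCE_TOKENS if is_nf4 else MAX_SHORT_INFERENCE_TOKENS
    return batch_size * length_increment <= limit

assert can_merge_pools(1, 1, is_nf4=False)       # single-token decoding step: merged pool
assert not can_merge_pools(4, 64, is_nf4=False)  # 256 prompt tokens at once: per-block pools
assert not can_merge_pools(1, 2, is_nf4=True)    # NF4 merges only single-token steps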
@ -0,0 +1,229 @@
|
||||
"""
|
||||
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
||||
If necessary, one can rewrite this to implement a different behavior, such as:
|
||||
- loading files from a local data source (e.g. S3)
|
||||
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
||||
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
||||
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from hivemind.utils.logging import get_logger
|
||||
from huggingface_hub import get_hf_file_metadata, hf_hub_url
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.utils import get_file_from_repo
|
||||
|
||||
from petals.constants import DTYPE_MAP
|
||||
from petals.models.mixtral import WrappedMixtralBlock
|
||||
from petals.server.block_utils import get_model_block, resolve_block_dtype
|
||||
from petals.utils.auto_config import AutoDistributedConfig
|
||||
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||
from petals.utils.hf_auth import always_needs_auth
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_pretrained_block(
|
||||
model_name: str,
|
||||
block_index: int,
|
||||
*,
|
||||
config: Optional[PretrainedConfig] = None,
|
||||
torch_dtype: Union[torch.dtype, str] = "auto",
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
max_disk_space: Optional[int] = None,
|
||||
) -> nn.Module:
|
||||
if config is None:
|
||||
config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
|
||||
if cache_dir is None:
|
||||
cache_dir = DEFAULT_CACHE_DIR
|
||||
|
||||
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
||||
torch_dtype = resolve_block_dtype(config, torch_dtype)
|
||||
|
||||
with init_empty_weights():
|
||||
block = get_model_block(config, layer_idx=block_index)
|
||||
|
||||
block_prefix = f"{config.block_prefix}.{block_index}."
|
||||
state_dict = _load_state_dict_from_repo(
|
||||
model_name,
|
||||
block_prefix,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
max_disk_space=max_disk_space,
|
||||
)
|
||||
|
||||
# dummy load, check that keys match
|
||||
report = block.load_state_dict(state_dict, strict=False)
|
||||
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
|
||||
|
||||
for param_name, _ in block.named_parameters():
|
||||
assert param_name in state_dict, f"{param_name} not in state dict"
|
||||
param = state_dict[param_name]
|
||||
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
param = param.to(torch_dtype)
|
||||
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
|
||||
|
||||
logger.info(f"Loaded {model_name} block {block_index}")
|
||||
logger.debug(f"Details: {report}")
|
||||
return block
|
||||
|
||||
|
||||
StateDict = Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def _load_state_dict_from_repo(
|
||||
model_name: str,
|
||||
block_prefix: str,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
cache_dir: str,
|
||||
max_disk_space: Optional[int] = None,
|
||||
) -> StateDict:
|
||||
if always_needs_auth(model_name) and token is None:
|
||||
token = True
|
||||
|
||||
index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
|
||||
if index_file.endswith(".index.json"): # Sharded model
|
||||
path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
|
||||
if path is None:
|
||||
# _find_index_file() said that a file exists, but we can't get it (e.g., it just disappeared)
|
||||
raise ValueError(f"Failed to get file {index_file}")
|
||||
|
||||
with open(path) as f:
|
||||
index = json.load(f)
|
||||
filenames = {
|
||||
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
|
||||
}
|
||||
if not filenames:
|
||||
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
|
||||
else: # Non-sharded model
|
||||
filenames = {index_file}
|
||||
logger.debug(f"Loading {block_prefix}* from {filenames}")
|
||||
|
||||
state_dict = {}
|
||||
for filename in filenames:
|
||||
shard_state_dict = _load_state_dict_from_repo_file(
|
||||
model_name,
|
||||
filename,
|
||||
block_prefix=block_prefix,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
max_disk_space=max_disk_space,
|
||||
)
|
||||
shard_state_dict = {
|
||||
param_name[len(block_prefix) :]: param
|
||||
for param_name, param in shard_state_dict.items()
|
||||
if param_name.startswith(block_prefix)
|
||||
} # Remove unused parameters from memory
|
||||
state_dict.update(shard_state_dict)
|
||||
return state_dict
|
||||
|
||||
|
||||
INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
|
||||
|
||||
|
||||
def _find_index_file(
|
||||
model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
|
||||
) -> str:
|
||||
# If we have cached weights (e.g., Pickle from older Petals versions), reuse them
|
||||
for filename in INDEX_FILES:
|
||||
path = get_file_from_repo(
|
||||
model_name,
|
||||
filename,
|
||||
revision=revision,
|
||||
use_auth_token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=True,
|
||||
)
|
||||
if path is not None:
|
||||
return filename
|
||||
|
||||
# If we don't, prefer Safetensors when possible
|
||||
# (we don't download files here since we can't account for max_disk_space in case of large files)
|
||||
for filename in INDEX_FILES:
|
||||
with suppress(EntryNotFoundError):
|
||||
get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
|
||||
return filename
|
||||
|
||||
raise ValueError(
|
||||
f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
|
||||
)
|
||||
|
||||
|
||||
def _load_state_dict_from_repo_file(
|
||||
model_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
block_prefix: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
cache_dir: str,
|
||||
max_disk_space: Optional[int] = None,
|
||||
delay: float = 30,
|
||||
) -> StateDict:
|
||||
# First, try to find the weights locally
|
||||
try:
|
||||
with allow_cache_reads(cache_dir):
|
||||
path = get_file_from_repo(
|
||||
model_name,
|
||||
filename,
|
||||
revision=revision,
|
||||
use_auth_token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=True,
|
||||
)
|
||||
if path is not None:
|
||||
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
|
||||
except Exception:
|
||||
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
|
||||
|
||||
# If not found, ensure that we have enough disk space to download them (maybe remove something)
|
||||
while True:
|
||||
try:
|
||||
with allow_cache_writes(cache_dir):
|
||||
url = hf_hub_url(model_name, filename, revision=revision)
|
||||
file_size = get_hf_file_metadata(url, token=token).size
|
||||
if file_size is not None:
|
||||
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
||||
else:
|
||||
logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
|
||||
|
||||
path = get_file_from_repo(
|
||||
model_name,
|
||||
filename,
|
||||
revision=revision,
|
||||
use_auth_token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
if path is None:
|
||||
raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
|
||||
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
|
||||
if path.endswith(".bin"):
|
||||
return torch.load(path, map_location="cpu")
|
||||
|
||||
if path.endswith(".safetensors"):
|
||||
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
|
||||
return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
|
||||
|
||||
raise ValueError(f"Unknown weight format: {path}")
|
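A minimal sketch of calling `load_pretrained_block` directly; the repo id, block index, and dtype are illustrative, and the module path is assumed from this file's imports.

import torch
from petals.server.from_pretrained import load_pretrained_block

block = load_pretrained_block(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # hypothetical repo id
    block_index=3,
    torch_dtype=torch.bfloat16,
)
print(sum(p.numel() for p in block.parameters()), "parameters loaded on CPU")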
@ -0,0 +1,7 @@
from petals.utils.auto_config import (
    AutoDistributedConfig,
    AutoDistributedModel,
    AutoDistributedModelForCausalLM,
    AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos
@ -0,0 +1,94 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from hivemind import get_logger
|
||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
from petals.utils.hf_auth import always_needs_auth
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ModelClasses:
|
||||
config: Type[PretrainedConfig]
|
||||
model: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
|
||||
|
||||
|
||||
_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
|
||||
|
||||
|
||||
def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
|
||||
assert issubclass(config, PretrainedConfig)
|
||||
assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
|
||||
|
||||
_CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
|
||||
|
||||
|
||||
class _AutoDistributedBase:
|
||||
_mapping_field = None # Should be defined in child classes
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
|
||||
if (
|
||||
always_needs_auth(model_name_or_path)
|
||||
and kwargs.get("token") is None
|
||||
and kwargs.get("use_auth_token") is None
|
||||
):
|
||||
kwargs["use_auth_token"] = True
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
|
||||
if config.model_type not in _CLASS_MAPPING:
|
||||
raise ValueError(f"Petals does not support model type {config.model_type}")
|
||||
|
||||
proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
|
||||
if proper_cls is None:
|
||||
raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
|
||||
|
||||
return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
|
||||
|
||||
|
||||
class DefaultRevisionMixin:
|
||||
"""
|
||||
Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
|
||||
TII models were recently converted to this format but then reverted back due to compatibility issues.
|
||||
We chose to support only the new format since HF staff promised to eventually convert these models
|
||||
to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
|
||||
Until that happens, we override the default `main` revision for the TII repos with the commit
pointing to the model in the in-library format.
|
||||
"""
|
||||
|
||||
DEFAULT_REVISIONS = {
|
||||
"tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
|
||||
"tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
|
||||
"tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
|
||||
"tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
|
||||
):
|
||||
if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
|
||||
revision = cls.DEFAULT_REVISIONS[model_name_or_path]
|
||||
logger.info(f"Loading {model_name_or_path}, revision {revision}")
|
||||
return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
|
||||
|
||||
|
||||
class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
|
||||
_mapping_field = "config"
|
||||
|
||||
|
||||
class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
|
||||
_mapping_field = "model"
|
||||
|
||||
|
||||
class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
|
||||
_mapping_field = "model_for_causal_lm"
|
||||
|
||||
|
||||
class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
|
||||
_mapping_field = "model_for_sequence_classification"
|
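For context, this is roughly what a `petals.models.*` subpackage is expected to do to populate `_CLASS_MAPPING`; the exact import paths are assumed from the Mixtral files earlier in this diff.

from petals.models.mixtral.config import DistributedMixtralConfig
from petals.models.mixtral.model import (
    DistributedMixtralForCausalLM,
    DistributedMixtralForSequenceClassification,
    DistributedMixtralModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedMixtralConfig,
    model=DistributedMixtralModel,
    model_for_causal_lm=DistributedMixtralForCausalLM,
    model_for_sequence_classification=DistributedMixtralForSequenceClassification,
)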
@ -0,0 +1,76 @@
|
||||
import torch
|
||||
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
|
||||
|
||||
|
||||
def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
|
||||
"""Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
|
||||
assert not isinstance(callable, torch.nn.Module)
|
||||
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
|
||||
raise RuntimeError(
|
||||
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
|
||||
)
|
||||
|
||||
flatten_arg, _ = _tree_flatten(sample_args)
|
||||
flatten_sample_args = tuple(flatten_arg)
|
||||
assert all(
|
||||
isinstance(arg, torch.Tensor) for arg in flatten_arg
|
||||
), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
|
||||
|
||||
len_user_args = len(sample_args)
|
||||
static_input_surface = flatten_sample_args
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Warmup
|
||||
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
|
||||
# from ending up in any captures.
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(num_warmup_iters):
|
||||
outputs, _ = _tree_flatten(callable(*sample_args))
|
||||
del outputs
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
# Capture forward graph
|
||||
with torch.cuda.graph(graph):
|
||||
outputs = callable(*sample_args)
|
||||
|
||||
flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
|
||||
static_outputs = tuple(flatten_outputs)
|
||||
|
||||
def make_graphed_function(
|
||||
graph,
|
||||
len_user_args,
|
||||
output_unflatten_spec,
|
||||
static_input_surface,
|
||||
static_outputs,
|
||||
):
|
||||
def replay_graph(*inputs):
|
||||
# At this stage, only the user args may (potentially) be new tensors.
|
||||
for i in range(len_user_args):
|
||||
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
|
||||
static_input_surface[i].copy_(inputs[i])
|
||||
graph.replay()
|
||||
assert isinstance(static_outputs, tuple)
|
||||
return tuple(o.detach() for o in static_outputs)
|
||||
|
||||
def functionalized(*user_args):
|
||||
# Runs the autograd function with inputs == all inputs to the graph that might require grad
|
||||
# (explicit user args + module parameters)
|
||||
# Assumes module params didn't change since capture.
|
||||
flatten_user_args, _ = _tree_flatten(user_args)
|
||||
out = replay_graph(*flatten_user_args)
|
||||
return _tree_unflatten(out, output_unflatten_spec)
|
||||
|
||||
return functionalized
|
||||
|
||||
# Put together the final graphed callable
|
||||
graphed = make_graphed_function(
|
||||
graph,
|
||||
len_user_args,
|
||||
output_unflatten_spec,
|
||||
static_input_surface,
|
||||
static_outputs,
|
||||
)
|
||||
return graphed
|
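A small usage sketch for `make_inference_graphed_callable`; it requires a CUDA device, and the module, shapes, and import path below are hypothetical.

import torch
from petals.utils.cuda_graphs import make_inference_graphed_callable

linear = torch.nn.Linear(1024, 1024).cuda().eval()
sample_input = torch.randn(1, 1024, device="cuda")

graphed_forward = make_inference_graphed_callable(linear.forward, sample_args=(sample_input,))
out1 = graphed_forward(sample_input)
out2 = graphed_forward(torch.randn_like(sample_input))  # new data is copied into the graph's static buffers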
@ -0,0 +1,153 @@
|
||||
"""
|
||||
Utilities for declaring and retrieving active model layers using a shared DHT.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
from hivemind.dht import DHT, DHTNode, DHTValue
|
||||
from hivemind.p2p import PeerID
|
||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
||||
|
||||
from petals.data_structures import (
|
||||
CHAIN_DELIMITER,
|
||||
UID_DELIMITER,
|
||||
ModuleUID,
|
||||
RemoteModuleInfo,
|
||||
RemoteSpanInfo,
|
||||
ServerInfo,
|
||||
ServerState,
|
||||
parse_uid,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def declare_active_modules(
|
||||
dht: DHT,
|
||||
uids: Sequence[ModuleUID],
|
||||
server_info: ServerInfo,
|
||||
expiration_time: DHTExpiration,
|
||||
wait: bool = True,
|
||||
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
|
||||
"""
|
||||
Declare that your node serves the specified modules; update timestamps if declared previously
|
||||
|
||||
:param uids: a list of module ids to declare
|
||||
:param wait: if True, waits for the declaration to finish, otherwise runs in the background
|
||||
:param throughput: specify your performance in terms of compute throughput
|
||||
:param expiration_time: declared modules will be visible for this many seconds
|
||||
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
|
||||
"""
|
||||
if isinstance(uids, str):
|
||||
uids = [uids]
|
||||
if not isinstance(uids, list):
|
||||
uids = list(uids)
|
||||
for uid in uids:
|
||||
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
|
||||
|
||||
return dht.run_coroutine(
|
||||
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
|
||||
return_future=not wait,
|
||||
)
|
||||
|
||||
|
||||
async def _declare_active_modules(
|
||||
dht: DHT,
|
||||
node: DHTNode,
|
||||
uids: List[ModuleUID],
|
||||
server_info: ServerInfo,
|
||||
expiration_time: DHTExpiration,
|
||||
) -> Dict[ModuleUID, bool]:
|
||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||
return await node.store_many(
|
||||
keys=uids,
|
||||
subkeys=[dht.peer_id.to_base58()] * len(uids),
|
||||
values=[server_info.to_tuple()] * len(uids),
|
||||
expiration_time=expiration_time,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
|
||||
def get_remote_module_infos(
|
||||
dht: DHT,
|
||||
uids: Sequence[ModuleUID],
|
||||
expiration_time: Optional[DHTExpiration] = None,
|
||||
active_adapter: Optional[str] = None,
|
||||
*,
|
||||
latest: bool = False,
|
||||
return_future: bool = False,
|
||||
) -> Union[List[RemoteModuleInfo], MPFuture]:
|
||||
return dht.run_coroutine(
|
||||
partial(
|
||||
_get_remote_module_infos,
|
||||
uids=uids,
|
||||
active_adapter=active_adapter,
|
||||
expiration_time=expiration_time,
|
||||
latest=latest,
|
||||
),
|
||||
return_future=return_future,
|
||||
)
|
||||
|
||||
|
||||
async def _get_remote_module_infos(
|
||||
dht: DHT,
|
||||
node: DHTNode,
|
||||
uids: List[ModuleUID],
|
||||
active_adapter: Optional[str],
|
||||
expiration_time: Optional[DHTExpiration],
|
||||
latest: bool,
|
||||
) -> List[RemoteModuleInfo]:
|
||||
if latest:
|
||||
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
|
||||
expiration_time = math.inf
|
||||
elif expiration_time is None:
|
||||
expiration_time = get_dht_time()
|
||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
|
||||
|
||||
modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
|
||||
for module_info in modules:
|
||||
metadata = found[module_info.uid]
|
||||
if metadata is None or not isinstance(metadata.value, dict):
|
||||
if metadata is not None:
|
||||
logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
|
||||
continue
|
||||
|
||||
for peer_id, server_info in metadata.value.items():
|
||||
try:
|
||||
peer_id = PeerID.from_base58(peer_id)
|
||||
server_info = ServerInfo.from_tuple(server_info.value)
|
||||
|
||||
if active_adapter and active_adapter not in server_info.adapters:
|
||||
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
|
||||
continue
|
||||
|
||||
module_info.servers[peer_id] = server_info
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
|
||||
return modules
|
||||
|
||||
|
||||
def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
|
||||
block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
|
||||
num_blocks = len(module_infos)
|
||||
|
||||
spans = {}
|
||||
for block_idx, module_info in enumerate(module_infos):
|
||||
for peer_id, server_info in sorted(module_info.servers.items()):
|
||||
if server_info.state.value < min_state.value:
|
||||
continue
|
||||
|
||||
if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
|
||||
spans[peer_id] = RemoteSpanInfo(
|
||||
peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
|
||||
)
|
||||
if server_info.start_block is not None and server_info.end_block is not None:
|
||||
spans[peer_id].start = max(server_info.start_block - block_offset, 0)
|
||||
spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
|
||||
elif spans[peer_id].state == server_info.state:
|
||||
spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
|
||||
return spans
|
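A hedged sketch of how these helpers fit together on the client side: query the DHT for module infos, then collapse them into per-server spans. The dht_prefix, block count, and bootstrap peers are illustrative.

from hivemind import DHT

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ServerState
from petals.utils.dht import compute_spans, get_remote_module_infos

dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
uids = [f"Mixtral-8x7B-hf.{i}" for i in range(32)]  # hypothetical dht_prefix and block count

module_infos = get_remote_module_infos(dht, uids, latest=True)
spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
for peer_id, span in spans.items():
    print(peer_id, span.start, span.end)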
@ -1,129 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
TokenIds = torch.Tensor
|
||||
HypoIds = torch.Tensor
|
||||
|
||||
|
||||
class DecodingAlgorithm(ABC):
|
||||
"""
|
||||
An abstract class for decoding algorithms. Describes the base function of those algorithms:
|
||||
they have to select new tokens and provide the corresponding hypotheses.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
"""
|
||||
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
|
||||
:return: A tuple of selected token ids and corresponding hypotheses.
|
||||
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class GreedyAlgorithm(DecodingAlgorithm):
|
||||
"""
|
||||
The simplest algorithm for decoding. It selects the most probable token.
|
||||
"""
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
"""
|
||||
Returns the most probable token. The second returned object is always a range of integers
|
||||
from 0 to batch_size - 1.
|
||||
"""
|
||||
return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
|
||||
|
||||
|
||||
class SamplingAlgorithm(DecodingAlgorithm):
|
||||
def __init__(self, temperature: float = 1.0):
|
||||
self.temperature = temperature
|
||||
|
||||
def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
"""
|
||||
:param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
|
||||
:param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
|
||||
:return: A tuple of selected token ids and corresponding hypotheses.
|
||||
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
|
||||
"""
|
||||
logits[indices_to_remove] = -float("Inf")
|
||||
probs = torch.softmax(logits / self.temperature, -1)
|
||||
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
|
||||
class TopKAlgorithm(SamplingAlgorithm):
|
||||
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
|
||||
super().__init__(temperature=temperature)
|
||||
self.top_k = top_k
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
|
||||
class NucleusAlgorithm(SamplingAlgorithm):
|
||||
def __init__(self, top_p: float, temperature: float = 1.0) -> None:
|
||||
super().__init__(temperature=temperature)
|
||||
self.top_p = top_p
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
|
||||
probs = torch.softmax(sorted_logits / self.temperature, -1)
|
||||
cumulative_probs = torch.cumsum(probs, dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
|
||||
class BeamSearchAlgorithm(DecodingAlgorithm):
|
||||
def __init__(self, num_beams: int, batch_size: int) -> None:
|
||||
self.num_beams = num_beams
|
||||
self._cur_num_beams = 1
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._batch_beams = [list() for _ in range(batch_size)]
|
||||
|
||||
def __call__(self, logits: torch.Tensor):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
probs = torch.log_softmax(sorted_logits, -1)
|
||||
|
||||
if len(self._batch_beams[0]) > 0:
|
||||
for batch_idx in range(self.batch_size):
|
||||
new_beams = []
|
||||
cur_beams = self._batch_beams[batch_idx]
|
||||
for beam_idx in range(len(cur_beams)):
|
||||
probs_idx = batch_idx + beam_idx * self.batch_size
|
||||
new_beam = cur_beams[beam_idx]
|
||||
for hypo_idx in range(self.num_beams):
|
||||
new_beams.append(
|
||||
(new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
|
||||
)
|
||||
self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
|
||||
else:
|
||||
for batch_idx in range(self.batch_size):
|
||||
for beam_idx in range(self.num_beams):
|
||||
self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
|
||||
|
||||
return_hypos = []
|
||||
return_tokens = []
|
||||
for batch_idx in range(self.batch_size):
|
||||
cur_beam = self._batch_beams[batch_idx]
|
||||
return_hypos.append(list())
|
||||
return_tokens.append(list())
|
||||
for beam in cur_beam:
|
||||
beam_idx = beam[1] // self.num_beams
|
||||
hypo_idx = batch_idx + beam_idx * self.batch_size
|
||||
token_idx = beam[1] % self.num_beams
|
||||
return_hypos[-1].append(hypo_idx)
|
||||
return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
|
||||
return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
|
||||
return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
|
||||
|
||||
return torch.tensor(return_tokens), torch.tensor(return_hypos)
|
@ -1,51 +0,0 @@
from abc import ABC

import torch


class ABCBloomConstraint(ABC):
    """
    Base class for all kinds of decoding constraints. It can be used to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass


class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated on the previous step.
    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logits that can be generated. Default: -1e8.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None

        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
@ -0,0 +1,7 @@
import os
from typing import Union


def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
    loading_from_repo = model_name is not None and not os.path.isdir(model_name)
    return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
@ -0,0 +1,49 @@
from typing import Any, Dict, List, Tuple

import torch
from hivemind import nested_flatten, nested_pack

# TODO: Move functions to hivemind


def _mark_masked_tensor(index: int) -> bytes:
    return b"__T" + str(index).encode()


def _is_masked_tensor(item: Any) -> bool:
    return isinstance(item, bytes) and item.startswith(b"__T")


def _get_tensor_index(item: bytes) -> int:
    return int(item[3:])


def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
    """
    Check the function's arguments and pack all tensors into a separate flattened list.
    :returns: a flat list of tensors, plus the (args, kwargs) structure with tensors replaced by masks
    """
    masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
    for value in nested_flatten((args, kwargs)):
        if isinstance(value, torch.Tensor):
            tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
            if tensor_index == len(flat_tensors):
                flat_tensors.append(value)
            masked_flat_values.append(_mark_masked_tensor(tensor_index))
        else:
            masked_flat_values.append(value)
    return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))


def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
    """
    Restore the original args and kwargs packed by `pack_args_kwargs`.
    :returns: a list of args and a dict of kwargs
    """
    return nested_pack(
        (
            value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
            for value in nested_flatten(args_structure)
        ),
        args_structure,
    )
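A quick roundtrip example for the two helpers above; the argument values are arbitrary.

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

x, y = torch.zeros(2, 3), torch.ones(4)
flat_tensors, structure = pack_args_kwargs(x, scale=2.0, state={"y": y})
# flat_tensors holds the unique tensors; structure mirrors (args, kwargs) with b"__T0", b"__T1" placeholders

args, kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(args[0], x) and kwargs["scale"] == 2.0 and kwargs["state"]["y"] is y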
@ -0,0 +1,288 @@
|
||||
import contextlib
|
||||
import re
|
||||
import time
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from hivemind.utils.logging import get_logger
|
||||
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
||||
from peft.config import PeftConfig
|
||||
from peft.tuners import lora
|
||||
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file
|
||||
from transformers.utils import get_file_from_repo
|
||||
|
||||
from petals.server.block_utils import get_model_block, resolve_block_dtype
|
||||
from petals.utils.convert_block import QuantType
|
||||
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||
from petals.utils.misc import get_size_in_bytes
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def check_peft_repository(repo_id: str) -> bool:
|
||||
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
|
||||
|
||||
|
||||
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
|
||||
tensors = dict()
|
||||
is_tensors_found = dict()
|
||||
common_layer_pattern_re = (
    r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
)
|
||||
with safe_open(filepath, framework=framework, device=device) as f:
|
||||
for k in f.keys():
|
||||
if re.match(common_layer_pattern_re, k):
|
||||
is_tensors_found[block_idx] = True
|
||||
tensors[k] = f.get_tensor(k)
|
||||
if not is_tensors_found.get(block_idx, False):
|
||||
logger.warning(f"There is no peft weights for block {block_idx}")
|
||||
return tensors
|
||||
|
||||
|
||||
def get_adapter_from_repo(
|
||||
repo_id: str,
|
||||
block_idx: Optional[int] = None,
|
||||
device: Optional[int] = None,
|
||||
*,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
|
||||
if config_path is None:
|
||||
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
||||
config = PeftConfig.from_json_file(config_path)
|
||||
|
||||
weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
|
||||
if weight_path is None:
|
||||
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
||||
if block_idx is None:
|
||||
return config, load_file(weight_path)
|
||||
return config, load_specific_module(block_idx, weight_path, device=device)
|
||||
|
||||
|
||||
def load_peft(
|
||||
repo_id: str,
|
||||
block_idx: Optional[int] = None,
|
||||
device: Optional[int] = None,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
cache_dir: str,
|
||||
max_disk_space: Optional[int] = None,
|
||||
delay: float = 30,
|
||||
):
|
||||
# TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
|
||||
|
||||
if not check_peft_repository(repo_id):
|
||||
raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
|
||||
|
||||
try:
|
||||
with allow_cache_reads(cache_dir):
|
||||
return get_adapter_from_repo(
|
||||
repo_id,
|
||||
block_idx,
|
||||
device,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
|
||||
|
||||
while True:
|
||||
try:
|
||||
with allow_cache_writes(cache_dir):
|
||||
config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
|
||||
config_file_size = get_hf_file_metadata(config_url, token=token).size
|
||||
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
||||
weight_file_size = get_hf_file_metadata(weight_url, token=token).size
|
||||
|
||||
if config_file_size is not None and weight_file_size is not None:
    free_disk_space_for(
        config_file_size + weight_file_size, cache_dir=cache_dir, max_disk_space=max_disk_space
    )
else:
    logger.warning(f"Failed to fetch size from peft repo {repo_id}")
|
||||
|
||||
return get_adapter_from_repo(
|
||||
repo_id,
|
||||
block_idx,
|
||||
device,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
class AdapterContextMixin:
|
||||
"""A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
|
||||
|
||||
ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
|
||||
_context_active_adapter = ADAPTER_NOT_SET
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def using_adapter(active_adapter: Optional[str]):
|
||||
prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AdapterContextMixin._context_active_adapter = prev
|
||||
|
||||
@property
|
||||
def active_adapter(self):
|
||||
if self._context_active_adapter == self.ADAPTER_NOT_SET:
|
||||
logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
|
||||
return self._context_active_adapter
|
||||
|
||||
@active_adapter.setter
|
||||
def active_adapter(self, value: Optional[str]):
|
||||
assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"
|
||||
|
||||
|
||||
using_adapter = AdapterContextMixin.using_adapter
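# Editor-added usage sketch (not part of the original diff): `block` and `hidden_states`
# below are hypothetical stand-ins for a LoRA-wrapped transformer block and its input.
#
#     with using_adapter("my-org/my-lora-adapter"):
#         outputs = block(hidden_states)  # every LoraLinear inside selects this adapter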
|
||||
|
||||
|
||||
class LoraLinear(AdapterContextMixin, lora.Linear):
|
||||
"""LoRA linear layer that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
|
||||
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
|
||||
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
def create_lora_adapter(block, quant_type: QuantType):
|
||||
for _, module in block.named_modules():
|
||||
for child_name, child in module.named_children():
|
||||
lora_wrapped_child = None
|
||||
if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
|
||||
continue
|
||||
if quant_type == QuantType.INT8:
|
||||
kwargs = {
|
||||
"has_fp16_weights": False,
|
||||
"threshold": 6.0,
|
||||
"bias": hasattr(child, "bias") and child.bias is not None,
|
||||
}
|
||||
lora_wrapped_child = LoraLinear8bitLt(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
**kwargs,
|
||||
)
|
||||
elif quant_type == QuantType.NF4:
|
||||
kwargs = {
|
||||
"compress_statistics": True,
|
||||
"quant_type": "nf4",
|
||||
"blocksize": 64,
|
||||
"bias": hasattr(child, "bias") and child.bias is not None,
|
||||
}
|
||||
lora_wrapped_child = LoraLinear4bit(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
**kwargs,
|
||||
)
|
||||
lora_wrapped_child.compute_dtype = child.compute_dtype
|
||||
else:
|
||||
bias = hasattr(child, "bias") and child.bias is not None
|
||||
lora_wrapped_child = LoraLinear(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=bias,
|
||||
)
|
||||
if lora_wrapped_child:
|
||||
lora_wrapped_child.weight = child.weight
|
||||
lora_wrapped_child.bias = child.bias
|
||||
for p in lora_wrapped_child.parameters():
|
||||
p.requires_grad = False
|
||||
setattr(module, child_name, lora_wrapped_child)
|
||||
|
||||
|
||||
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
|
||||
assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
|
||||
if peft_config["lora_dropout"] > 0:
|
||||
logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
|
||||
|
||||
for _, module in block.named_modules():
|
||||
for child_name, child in module.named_children():
|
||||
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
|
||||
continue
|
||||
|
||||
if child_name in peft_config["target_modules"] or (
|
||||
isinstance(peft_config["target_modules"], str)
|
||||
and re.fullmatch(peft_config["target_modules"], child_name)
|
||||
):
|
||||
is_lora_a_loaded = False
|
||||
is_lora_b_loaded = False
|
||||
for peft_key in peft_state_dict:
|
||||
if child_name not in peft_key:
|
||||
continue
|
||||
|
||||
if adapter_name not in child.lora_A:
|
||||
child.update_layer(
|
||||
adapter_name,
|
||||
peft_config["r"],
|
||||
peft_config["lora_alpha"],
|
||||
lora_dropout=peft_config["lora_dropout"],
|
||||
init_lora_weights=peft_config["init_lora_weights"],
|
||||
)
|
||||
child.train(False)
|
||||
for p in child.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
if peft_key.endswith(".lora_A.weight"):
|
||||
child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
|
||||
is_lora_a_loaded = True
|
||||
elif peft_key.endswith(".lora_A.bias"):
|
||||
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
|
||||
elif peft_key.endswith(".lora_B.weight"):
|
||||
child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
|
||||
is_lora_b_loaded = True
|
||||
elif peft_key.endswith(".lora_B.bias"):
|
||||
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
|
||||
|
||||
if is_lora_a_loaded and is_lora_b_loaded:
|
||||
logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
|
||||
elif is_lora_a_loaded or is_lora_b_loaded:
|
||||
raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
|
||||
logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
|
||||
|
||||
|
||||
def estimate_adapter_memory_per_block(
|
||||
block_config: transformers.PretrainedConfig,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
adapters: Sequence[str],
|
||||
**load_peft_kwargs,
|
||||
) -> int:
|
||||
"""Get the number of extra bytes used to store a set of adapters per given block"""
|
||||
with init_empty_weights(include_buffers=True):
|
||||
block = get_model_block(block_config)
|
||||
base_block_parameters = sum(p.numel() for p in block.parameters())
|
||||
create_lora_adapter(block, quant_type=QuantType.NONE)
|
||||
|
||||
for adapter in adapters:
|
||||
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
|
||||
assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
|
||||
add_adapter_to_block(
|
||||
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
|
||||
)
|
||||
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
|
||||
bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
|
||||
return adapter_parameters * bytes_per_parameter
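# Editor-added note (not part of the original diff): a back-of-the-envelope version of what
# this function computes, with made-up sizes. A LoRA adapter of rank r adds an
# (in_features x r) A matrix and an (r x out_features) B matrix to each wrapped linear layer,
# so for a single 4096 -> 4096 projection with r = 8:
#     extra_params = 4096 * 8 + 8 * 4096 = 65_536
# and with 2-byte (fp16/bf16) parameters that is 65_536 * 2 = 131_072 bytes (128 KiB),
# summed over all target modules of the block.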
|
@@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import hivemind
|
||||
from hivemind.proto import dht_pb2
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def ping(
|
||||
peer_id: hivemind.PeerID,
|
||||
_dht: hivemind.DHT,
|
||||
node: hivemind.dht.DHTNode,
|
||||
*,
|
||||
wait_timeout: float = 5,
|
||||
) -> float:
|
||||
try:
|
||||
ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
|
||||
start_time = time.perf_counter()
|
||||
await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
|
||||
return time.perf_counter() - start_time
|
||||
except Exception as e:
|
||||
if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays)
|
||||
return time.perf_counter() - start_time
|
||||
|
||||
logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
|
||||
return math.inf
|
||||
|
||||
|
||||
async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
|
||||
rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
|
||||
return dict(zip(peer_ids, rpc_infos))
|
||||
|
||||
|
||||
class PingAggregator:
|
||||
def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
|
||||
self.dht = dht
|
||||
self.ema_alpha = ema_alpha
|
||||
self.expiration = expiration
|
||||
self.ping_emas = hivemind.TimedStorage()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
|
||||
current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
|
||||
logger.debug(f"Current RTTs: {current_rtts}")
|
||||
|
||||
with self.lock:
|
||||
expiration = hivemind.get_dht_time() + self.expiration
|
||||
for peer_id, rtt in current_rtts.items():
|
||||
prev_rtt = self.ping_emas.get(peer_id)
|
||||
if prev_rtt is not None and prev_rtt.value != math.inf:
|
||||
rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
|
||||
self.ping_emas.store(peer_id, rtt, expiration)
|
||||
|
||||
def to_dict(self) -> Dict[hivemind.PeerID, float]:
|
||||
with self.lock, self.ping_emas.freeze():
|
||||
smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
|
||||
logger.debug(f"Smothed RTTs: {smoothed_rtts}")
|
||||
return smoothed_rtts
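# Editor-added usage sketch (not part of the original diff). `dht` is assumed to be a running
# hivemind.DHT and `peer_ids` a sequence of hivemind.PeerID. Each .ping() call folds fresh
# round-trip times into an exponential moving average:
#     ema_new = ema_alpha * rtt + (1 - ema_alpha) * ema_old
#
#     aggregator = PingAggregator(dht, ema_alpha=0.2, expiration=300)
#     aggregator.ping(peer_ids)
#     rtts = aggregator.to_dict()  # {peer_id: smoothed RTT in seconds, math.inf if unreachable}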
|
@@ -0,0 +1,12 @@
|
||||
import random
|
||||
from typing import Collection, List, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def sample_up_to(population: Collection[T], k: int) -> List[T]:
|
||||
if not isinstance(population, list):
|
||||
population = list(population)
|
||||
if len(population) > k:
|
||||
population = random.sample(population, k)
|
||||
return population
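# Editor-added examples (not part of the original diff) of the helper above:
#
#     sample_up_to(range(10), 3)    # -> 3 random distinct elements of 0..9
#     sample_up_to(["a", "b"], 5)   # -> ["a", "b"] unchanged, since len(population) <= k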
|
@ -1,25 +0,0 @@
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
from huggingface_hub import delete_repo, list_models
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
|
||||
parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
|
||||
parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
|
||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
||||
parser.add_argument("--dry_run", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
for model in list_models(author=args.author, full=True):
|
||||
last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
if model.modelId.endswith("-main") or "/test-" not in model.modelId:
|
||||
continue # remove only test models
|
||||
|
||||
if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
|
||||
if args.dry_run:
|
||||
print(f"{model.modelId} can be deleted")
|
||||
else:
|
||||
delete_repo(repo_id=model.modelId, token=args.use_auth_token)
|
@@ -1,22 +1,75 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from hivemind import nested_compare, nested_flatten
|
||||
|
||||
from petals.client import DistributedBloomConfig
|
||||
from petals import AutoDistributedConfig
|
||||
from petals.server.throughput import measure_compute_rps
|
||||
from petals.utils.convert_block import QuantType
|
||||
from petals.utils.misc import DUMMY, is_dummy
|
||||
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
|
||||
from test_utils import MODEL_NAME
|
||||
|
||||
|
||||
def test_bnb_not_imported_when_unnecessary():
|
||||
"""
|
||||
We avoid importing bitsandbytes when it's not used,
|
||||
since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.
|
||||
|
||||
If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft
|
||||
in the function's or method's body when it's actually needed, instead of importing them at the beginning of the file.
|
||||
This won't slow down the code: importing a module a second time doesn't rerun module code.
|
||||
"""
|
||||
|
||||
subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"])
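# Editor-added sketch (not part of the original diff) of the lazy-import pattern the docstring
# above asks for; quantize_block() is a hypothetical function, not a real petals API.
#
#     def quantize_block(block):
#         import bitsandbytes as bnb  # imported only when quantization is actually requested
#         ...  # use bnb here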
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize("inference", [False, True])
|
||||
@pytest.mark.parametrize("n_tokens", [1, 16])
|
||||
@pytest.mark.parametrize("tensor_parallel", [False, True])
|
||||
def test_compute_throughput(tensor_parallel: bool):
|
||||
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||
if tensor_parallel and config.model_type != "bloom":
|
||||
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
|
||||
|
||||
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
|
||||
compute_rps = measure_compute_rps(
|
||||
config,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.bfloat16,
|
||||
load_in_8bit=False,
|
||||
quant_type=QuantType.NONE,
|
||||
tensor_parallel_devices=tensor_parallel_devices,
|
||||
n_steps=10,
|
||||
n_tokens=n_tokens,
|
||||
n_steps=5,
|
||||
inference=inference,
|
||||
)
|
||||
assert isinstance(compute_rps, float) and compute_rps > 0
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_pack_inputs():
|
||||
x = torch.ones(3)
|
||||
y = torch.arange(5)
|
||||
z = DUMMY
|
||||
|
||||
args = (x, z, None, (y, y), z)
|
||||
kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
|
||||
|
||||
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
|
||||
|
||||
assert len(flat_tensors) == 5
|
||||
assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
|
||||
|
||||
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
|
||||
|
||||
assert len(restored_args) == len(args)
|
||||
assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
|
||||
assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
|
||||
for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
|
||||
if isinstance(original, torch.Tensor):
|
||||
assert torch.all(original == restored)
|
||||
else:
|
||||
assert original == restored
|
||||
|
@@ -1,85 +1,43 @@
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
|
||||
from petals.bloom.block import WrappedBloomBlock
|
||||
from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
|
||||
from petals.client import DistributedBloomConfig, RemoteSequential
|
||||
from petals.data_structures import UID_DELIMITER
|
||||
from petals import AutoDistributedConfig, RemoteSequential
|
||||
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
|
||||
from petals.server.from_pretrained import load_pretrained_block
|
||||
from test_utils import *
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
|
||||
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
||||
remote_sequential = RemoteSequential(config)
|
||||
|
||||
for block_index in random.sample(range(config.n_layer), 3):
|
||||
remote_block = remote_sequential[block_index]
|
||||
block_index = random.randint(0, config.num_hidden_layers - 1)
|
||||
remote_block = remote_sequential[block_index]
|
||||
|
||||
inputs = torch.randn(1, 8, config.hidden_size)
|
||||
outputs_forward = remote_block(inputs)
|
||||
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
|
||||
outputs_forward = remote_block(inputs)
|
||||
|
||||
outputs_inference = []
|
||||
outputs_inference = []
|
||||
with torch.inference_mode():
|
||||
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
|
||||
for i in range(inputs.shape[1]):
|
||||
# Test long inference (unmerged inference pools)
|
||||
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
|
||||
|
||||
# Test short inference (merged inference pools)
|
||||
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
|
||||
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
|
||||
|
||||
# test that max length is respected
|
||||
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
|
||||
sess.step(inputs[:, -1:, :])
|
||||
assert "Maximum length exceeded" in repr(exc_info.value)
|
||||
outputs_inference = torch.cat(outputs_inference, dim=1)
|
||||
|
||||
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
|
||||
(outputs_local,) = ref_block(inputs)
|
||||
|
||||
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
|
||||
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
|
||||
|
||||
|
||||
def _old_load_pretrained_block(
|
||||
converted_model_name_or_path: str,
|
||||
block_index: int,
|
||||
torch_dtype: Union[torch.dtype, str] = "auto",
|
||||
) -> WrappedBloomBlock:
|
||||
"""Load the BLOOM block by directly initializing the weights.
|
||||
This test is used to check consistency with the previous implementation and can be removed in the future."""
|
||||
config = BloomConfig.from_pretrained(converted_model_name_or_path)
|
||||
|
||||
block = WrappedBloomBlock(config)
|
||||
state_dict = _load_state_dict(
|
||||
converted_model_name_or_path,
|
||||
block_index,
|
||||
config,
|
||||
cache_dir=None,
|
||||
)
|
||||
|
||||
if torch_dtype == "auto":
|
||||
with torch.no_grad():
|
||||
for name, param in block.named_parameters():
|
||||
assert name in state_dict, f"{name} not in state dict"
|
||||
param.data = param.data.to(state_dict[name].dtype)
|
||||
else:
|
||||
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
||||
block = block.to(dtype=torch_dtype)
|
||||
|
||||
block.load_state_dict(state_dict, strict=True)
|
||||
return block
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
|
||||
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
torch.random.manual_seed(0)
|
||||
inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
|
||||
outputs_inference = torch.cat(outputs_inference, dim=1)
|
||||
|
||||
block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
|
||||
ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
|
||||
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
|
||||
(outputs_local,) = ref_block(inputs)
|
||||
|
||||
outputs = block.forward(inputs)[0]
|
||||
outputs_ref = ref_block.forward(inputs)[0]
|
||||
assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)
|
||||
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
|
||||
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
|
||||
|
@@ -0,0 +1,184 @@
|
||||
import asyncio
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped
|
||||
import torch
|
||||
from hivemind import TensorDescriptor
|
||||
|
||||
from petals.server.memory_cache import AllocationFailed, MemoryCache
|
||||
from petals.utils.misc import get_size_in_bytes
|
||||
|
||||
|
||||
def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
|
||||
if dtype is None:
|
||||
dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
|
||||
elem_size_bytes = get_size_in_bytes(dtype)
|
||||
descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
|
||||
return descr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_timeout():
|
||||
cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
|
||||
cache.runtime_pid += 1 # pretend we're another process
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
|
||||
pass
|
||||
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
|
||||
t_start = time.perf_counter()
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
|
||||
pass
|
||||
assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
|
||||
pass
|
||||
|
||||
t_start = time.perf_counter()
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout
|
||||
pass
|
||||
assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
|
||||
|
||||
# test memory allocation when another task frees the memory
|
||||
async def _klog_the_cache():
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
|
||||
pass
|
||||
|
||||
large_alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
|
||||
t_start = time.perf_counter()
|
||||
await asyncio.sleep(0.05) # wait for large alloc to enqueue
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max timeout
|
||||
pass # this memory should allocate once the background task clears the queue
|
||||
assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
|
||||
with pytest.raises(AllocationFailed):
|
||||
await large_alloc_task
|
||||
|
||||
# test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
|
||||
large_alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
t_start = time.perf_counter()
|
||||
await asyncio.sleep(0.05) # wait for large alloc to enqueue
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
|
||||
pass # this memory should allocate once the background task clears the queue
|
||||
assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
|
||||
with pytest.raises(AllocationFailed):
|
||||
await large_alloc_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlimited_timeout():
|
||||
cache = MemoryCache(max_size_bytes=1024)
|
||||
cache.runtime_pid += 1 # pretend we're another process
|
||||
t_start = time.perf_counter()
|
||||
|
||||
async def _klog_the_cache():
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
await asyncio.sleep(0.1)
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
|
||||
await alloc_task
|
||||
assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_usage():
|
||||
cache = MemoryCache(max_size_bytes=2048)
|
||||
alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
|
||||
pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
|
||||
with pytest.raises(AssertionError):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
|
||||
pass # fails because cache must be allocated from another process
|
||||
|
||||
descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes
|
||||
descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes
|
||||
descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes
|
||||
descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes
|
||||
descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes
|
||||
descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes
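# Editor-added note (not part of the original diff): the byte counts above are numel * element
# size, e.g. descr_e is 96 * 8 bfloat16 elements * 2 bytes = 1536 bytes. Hence a + b + c + d
# = 768 + 8 + 33 + 0 = 809 bytes fit into the 2048-byte cache, while adding e (1536 bytes)
# would not, which is what the assertions below rely on.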
|
||||
|
||||
async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
|
||||
loop = asyncio.get_event_loop()
|
||||
async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
|
||||
pipe_sender.send(handles)
|
||||
await loop.run_in_executor(None, dealloc_event.wait)
|
||||
|
||||
async def _allocate_af():
|
||||
alloc_event.wait()
|
||||
allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
|
||||
await allocate_a_task
|
||||
allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
|
||||
await allocate_f_task
|
||||
|
||||
alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
|
||||
alloc_process1.start()
|
||||
|
||||
async def _allocate_bcde():
|
||||
alloc_event.wait()
|
||||
await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first
|
||||
allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
|
||||
allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
|
||||
await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
|
||||
alloc_process2.start()
|
||||
assert cache.current_size_bytes == 0
|
||||
alloc_event.set()
|
||||
(handle_a,) = pipe_receiver.recv()
|
||||
|
||||
handle_b, handle_c, handle_d = pipe_receiver.recv()
|
||||
|
||||
with cache.use_cache(handle_a) as (tensor_a,):
|
||||
assert tensor_a.dtype == torch.uint8
|
||||
tensor_a[2:5] = torch.tensor((42, 43, 44))
|
||||
|
||||
with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
|
||||
assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
|
||||
assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
|
||||
tensor_a += 1
|
||||
tensor_b[...] = -1.337
|
||||
assert cache.current_size_bytes == 809  # this checks that a, b, c, d are allocated while e still awaits memory
|
||||
|
||||
dealloc_bcd_event.set()
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 768 # only tensor a should be allocated
|
||||
with pytest.raises(KeyError):
|
||||
with cache.use_cache(handle_a, handle_b):
|
||||
pass # one of handles (c) is deallocated
|
||||
with pytest.raises(KeyError):
|
||||
with cache.use_cache(handle_d):
|
||||
pass # handle_d is deallocated correctly, even though it is never used
|
||||
with cache.use_cache(handle_a) as (tensor_a,):
|
||||
assert tuple(tensor_a[2:5]) == (43, 44, 45)
|
||||
|
||||
dealloc_a_event.set()
|
||||
(handle_e,) = pipe_receiver.recv() # e can finally be allocated
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
with cache.use_cache(handle_a):
|
||||
pass # tensor a is no longer allocated
|
||||
with cache.use_cache(handle_e) as (tensor_e,):
|
||||
assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
|
||||
|
||||
dealloc_e_event.set()
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 1792 # only tensor f is still allocated
|
||||
dealloc_f_event.set()
|
||||
|
||||
alloc_process1.join()
|
||||
alloc_process2.join()
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 0
|
||||
assert cache.current_size_bytes == 0
|
||||
assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
|
||||
assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
|
@@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from petals.server.block_utils import resolve_block_dtype
|
||||
from petals.server.from_pretrained import load_pretrained_block
|
||||
from petals.utils.auto_config import AutoDistributedConfig
|
||||
from test_utils import MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
|
||||
def test_block_dtype(torch_dtype):
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||
block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
|
||||
expected_dtype = resolve_block_dtype(config, torch_dtype)
|
||||
assert all(param.dtype == expected_dtype for param in block.parameters())
|
@@ -0,0 +1,224 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
from petals.server.block_utils import get_model_block
|
||||
from petals.utils.auto_config import AutoDistributedConfig
|
||||
from petals.utils.convert_block import QuantType, convert_block
|
||||
from test_utils import MODEL_NAME
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
alibi: Optional[torch.Tensor] = None,
|
||||
layer_past: Optional[KVCache] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
if layer_past is not None:
|
||||
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
|
||||
past_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
seq_length_with_past = seq_length + past_length
|
||||
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
if alibi is None and self.config.alibi:
|
||||
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
|
||||
|
||||
outputs = super().forward(
|
||||
hidden_states,
|
||||
*args,
|
||||
attention_mask=attention_mask,
|
||||
alibi=alibi,
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
present_key_value = outputs[-1]
|
||||
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
|
||||
outputs = outputs[:-1] + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
|
||||
key_states, value_states = key_value
|
||||
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
key_states = self._expand_states(key_states)
|
||||
value_states = self._expand_states(value_states)
|
||||
|
||||
return (key_states, value_states)
|
||||
|
||||
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
|
||||
key_states, value_states = key_value
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
key_states = self._collapse_states(key_states)
|
||||
value_states = self._collapse_states(value_states)
|
||||
|
||||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
|
||||
return (key_states, value_states)
|
||||
|
||||
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
|
||||
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
|
||||
|
||||
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
|
||||
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
|
||||
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
|
||||
return state
|
||||
|
||||
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
|
||||
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
|
||||
|
||||
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
|
||||
state = state[:, :, 0]
|
||||
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
|
||||
return state
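def test_expand_collapse_states_shape_roundtrip():
    # Editor-added sketch (not part of the original diff): checks the pure shape arithmetic
    # behind _expand_states/_collapse_states above with made-up sizes, independent of any model.
    batch, num_kv_heads, groups, seq_len, head_dim = 2, 4, 8, 5, 16
    num_attention_heads = num_kv_heads * groups

    state = torch.randn(batch * num_kv_heads, seq_len, head_dim)
    expanded = (
        state.view(batch, num_kv_heads, 1, seq_len, head_dim)
        .expand(-1, -1, groups, -1, -1)
        .reshape(batch * num_attention_heads, seq_len, head_dim)
    )
    collapsed = (
        expanded.view(batch, num_kv_heads, groups, seq_len, head_dim)[:, :, 0]
        .reshape(batch * num_kv_heads, seq_len, head_dim)
    )
    assert torch.equal(collapsed, state)  # expanding and then collapsing KV heads is lossless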
|
||||
|
||||
|
||||
class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*args,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
past_key_value = layer_past
|
||||
if past_key_value is not None:
|
||||
past_key_values_length = past_key_value[0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
|
||||
elif use_cache:
|
||||
past_key_value = DynamicCache()
|
||||
|
||||
if position_ids is None:
|
||||
device = hidden_states.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
outputs = super().forward(
|
||||
hidden_states,
|
||||
*args,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
present_key_value = outputs[-1]
|
||||
present_key_value = self._reorder_cache_from_llama_to_bloom(
|
||||
present_key_value, batch_size, seq_length_with_past
|
||||
)
|
||||
outputs = outputs[:-1] + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _reorder_cache_from_bloom_to_llama(
|
||||
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||
) -> DynamicCache:
|
||||
key_states, value_states = key_value
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
key_states = key_states.view(
|
||||
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
value_states = value_states.view(*key_states.shape)
|
||||
past_key_values = ((key_states, value_states),)
|
||||
return DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
def _reorder_cache_from_llama_to_bloom(
|
||||
self, key_value: DynamicCache, batch_size: int, seq_length: int
|
||||
) -> Tuple[torch.Tensor]:
|
||||
key_states, value_states = key_value.to_legacy_cache()[0]
|
||||
value_states = value_states.view(
|
||||
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||
)
|
||||
key_states = key_states.view(*value_states.shape)
|
||||
key_states = key_states.permute(0, 2, 1)
|
||||
return (key_states, value_states)
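def test_bloom_llama_key_layout_roundtrip():
    # Editor-added sketch (not part of the original diff): illustrates the key layouts that the
    # reorder helpers above translate between, using made-up sizes and plain tensors.
    batch, num_kv_heads, seq_len, head_dim = 2, 4, 5, 16

    # Bloom-style caches keep keys transposed relative to values: [batch * heads, head_dim, seq]
    bloom_keys = torch.randn(batch * num_kv_heads, head_dim, seq_len)
    # Llama-style caches store keys as [batch, heads, seq, head_dim]
    llama_keys = bloom_keys.permute(0, 2, 1).reshape(batch, num_kv_heads, seq_len, head_dim)
    # Converting back restores the original bloom layout exactly
    restored = llama_keys.reshape(batch * num_kv_heads, seq_len, head_dim).permute(0, 2, 1)
    assert torch.equal(restored, bloom_keys)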
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
|
||||
@pytest.mark.forked
|
||||
def test_optimized_block(device):
|
||||
if device == "cuda:0" and not torch.cuda.is_available():
|
||||
pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
|
||||
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||
|
||||
tensor_parallel_devices = (device,)
|
||||
dtype = torch.bfloat16
|
||||
quant_type = QuantType.NONE
|
||||
|
||||
block_idx = 1
|
||||
block = get_model_block(config, layer_idx=block_idx).to(dtype)
|
||||
block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
|
||||
|
||||
if config.model_type == "falcon":
|
||||
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
|
||||
elif config.model_type == "llama":
|
||||
unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
|
||||
else:
|
||||
pytest.skip(f"This test is not applicable to {config.model_type} models")
|
||||
|
||||
unopt_block = convert_block(
|
||||
unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
|
||||
)
|
||||
|
||||
unopt_block.load_state_dict(block.state_dict())
|
||||
cache = unopt_cache = None
|
||||
|
||||
with torch.inference_mode():
|
||||
for length in [10, 1, 1, 1]:
|
||||
dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
|
||||
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
|
||||
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
|
||||
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
|
||||
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
|
||||
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
|
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from petals.utils.peft import check_peft_repository, load_peft
|
||||
|
||||
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
|
||||
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
|
||||
TMP_CACHE_DIR = "tmp_cache/"
|
||||
|
||||
|
||||
def clear_dir(path_to_dir):
|
||||
shutil.rmtree(path_to_dir)
|
||||
os.mkdir(path_to_dir)
|
||||
|
||||
|
||||
def dir_empty(path_to_dir):
|
||||
files = os.listdir(path_to_dir)
|
||||
return len(files) == 0
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_check_peft():
|
||||
assert not check_peft_repository(UNSAFE_PEFT_REPO), "UNSAFE_PEFT_REPO is unexpectedly considered safe to load."
|
||||
assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_noncached(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
with pytest.raises(Exception):
|
||||
load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_cached(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_layer_exists(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_layer_nonexists(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
|
||||
load_peft(
|
||||
SAFE_PEFT_REPO,
|
||||
block_idx=1337,
|
||||
cache_dir=tmpdir,
|
||||
)
|