Merge branch 'main' into repetition-penalty
commit
dd677d9e76
@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import multiprocessing as mp
|
||||||
|
from time import perf_counter
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
|
||||||
|
from petals import AutoDistributedModel
|
||||||
|
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI args, spawn one forward-benchmark worker per process, and log the mean speed."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # "--n_processes n_gpus" is a shortcut for one worker per visible CUDA device.
    args.n_processes = torch.cuda.device_count() if args.n_processes == "n_gpus" else int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    workers = [mp.Process(target=benchmark_forward, args=(rank, args, pipe_send)) for rank in range(args.n_processes)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Each worker reports one tokens/sec figure; aggregate with an unweighted mean.
    per_worker_speeds = [pipe_recv.recv() for _ in range(args.n_processes)]
    speed = np.mean(per_worker_speeds)
    logger.info(f"Final result: {speed=:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    """Worker: run repeated forward passes over random token batches and report tokens/sec via the pipe."""
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    timings = []
    total_steps = args.warmup_steps + args.n_steps
    for step in range(total_steps):
        started_at = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        _hidden = model(input_ids)
        # We don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        # Only time post-warmup steps; speed is a running average in tokens/sec.
        if step >= args.warmup_steps:
            timings.append(perf_counter() - started_at)
            speed = input_ids.numel() / np.mean(timings)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: parse args and run the multi-process forward benchmark.
if __name__ == "__main__":
    main()
|
@ -0,0 +1,72 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import multiprocessing as mp
|
||||||
|
from time import perf_counter
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from petals import AutoDistributedModelForCausalLM
|
||||||
|
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI args, spawn one inference-benchmark worker per process, and log the mean speed."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # "--n_processes n_gpus" is a shortcut for one worker per visible CUDA device.
    args.n_processes = torch.cuda.device_count() if args.n_processes == "n_gpus" else int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    workers = [
        mp.Process(target=benchmark_inference, args=(rank, args, pipe_send)) for rank in range(args.n_processes)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Each worker reports one steps/sec figure; aggregate with an unweighted mean.
    per_worker_speeds = [pipe_recv.recv() for _ in range(args.n_processes)]
    speed = np.mean(per_worker_speeds)
    logger.info(f"Final result: {speed=:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    """Worker: generate one token at a time inside a single inference session and report steps/sec."""
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    timings = []
    # NOTE(review): model.transformer.h assumes a BLOOM-style module layout — confirm for other architectures.
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            started_at = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            # Only time post-warmup steps; speed is a running average in steps/sec.
            if step >= args.warmup_steps:
                timings.append(perf_counter() - started_at)
                speed = 1 / np.mean(timings)
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: parse args and run the multi-process inference benchmark.
if __name__ == "__main__":
    main()
|
@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import multiprocessing as mp
|
||||||
|
from time import perf_counter
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
|
||||||
|
from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
|
||||||
|
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI args, spawn one training-benchmark worker per process, and log mean fwd/bwd speeds."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
    parser.add_argument("--task", type=str, default="cls", help="Training task type")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    assert args.task in ["cls", "causal_lm"]

    # "--n_processes n_gpus" is a shortcut for one worker per visible CUDA device.
    args.n_processes = torch.cuda.device_count() if args.n_processes == "n_gpus" else int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    workers = [
        mp.Process(target=benchmark_training, args=(rank, args, pipe_send)) for rank in range(args.n_processes)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Each worker reports a (fwd, bwd) pair; average elementwise across workers.
    fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
    logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_training(process_idx, args, result_pipe):
    """Worker: run forward + backward + optimizer steps on random data and report (fwd, bwd) tokens/sec."""
    # Both tasks share the prompt-tuning setup; they differ only in head type and labels.
    common_kwargs = dict(
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
        tuning_mode="deep_ptune",
        pre_seq_len=args.pre_seq_len,
    )
    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model, num_labels=2, **common_kwargs
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(args.model, **common_kwargs)
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        # Classification uses random binary labels; causal LM predicts its own input.
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        started_at = perf_counter()
        outputs = model(input_ids, labels=labels)
        if step >= args.warmup_steps:
            fwd_times.append(perf_counter() - started_at)

        logger.info(f"{process_idx=} {step=} Backward")
        started_at = perf_counter()
        outputs.loss.backward()
        if step >= args.warmup_steps:
            bwd_times.append(perf_counter() - started_at)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        optimizer.step()
        optimizer.zero_grad()

        # Running averages over post-warmup steps only, in tokens/sec.
        if step >= args.warmup_steps:
            fwd_speed = input_ids.numel() / np.mean(fwd_times)
            bwd_speed = input_ids.numel() / np.mean(bwd_times)
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    result_pipe.send((fwd_speed, bwd_speed))
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: parse args and run the multi-process training benchmark.
if __name__ == "__main__":
    main()
|
@ -1,6 +1,29 @@
|
|||||||
|
import os

# Must be set before bitsandbytes is first imported anywhere in the package.
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")

import hivemind
import transformers
from packaging import version

from petals.client import *
from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

__version__ = "2.0.1.post2"


# Fail fast on incompatible transformers versions; the check can be bypassed via env var.
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"


def _override_bfloat16_mode_default():
    """Disable hivemind's legacy bfloat16 compression unless the user explicitly opted in."""
    if os.getenv("USE_LEGACY_BFLOAT16") is None:
        hivemind.compression.base.USE_LEGACY_BFLOAT16 = False


_initialize_logs()
_override_bfloat16_mode_default()
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
"""
|
|
||||||
Bloom intermediate layer
|
|
||||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
||||||
See commit history for authorship.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch.nn.quantized.dynamic.modules.linear
|
|
||||||
import transformers
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
|
|
||||||
|
|
||||||
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
|
|
||||||
assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"
|
|
||||||
|
|
||||||
|
|
||||||
class WrappedBloomBlock(BloomBlock):
    """A BloomBlock that reconstructs its own attention mask and ALiBi tensor.

    Callers pass only hidden_states (and optionally layer_past); the mask is rebuilt here
    from the current and past sequence lengths instead of being forwarded by the client.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        # Callers must not pass a mask: it is recomputed below from the sequence lengths.
        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        # layer_past[0] holds cached keys; its last dim is the number of already-processed tokens.
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        # All positions are attendable (no padding mask); causality is applied in _prepare_attn_mask.
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        """Combine a causal mask with the (all-ones) padding mask into a boolean attention mask."""
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        # A causal mask is only needed when attending over more than one query position.
        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        # Union of the causal and padding masks (True = masked-out position).
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask
|
|
@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
|
||||||
If necessary, one can rewrite this to implement a different behavior, such as:
|
|
||||||
- loading files from a local data source (e.g. S3)
|
|
||||||
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
|
||||||
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import time
|
|
||||||
from typing import Optional, OrderedDict, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_utils import WEIGHTS_NAME
|
|
||||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
||||||
from transformers.utils import get_file_from_repo
|
|
||||||
|
|
||||||
from petals.bloom.block import WrappedBloomBlock
|
|
||||||
from petals.server.block_utils import get_block_size
|
|
||||||
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
CLIENT_BRANCH = "main"
|
|
||||||
BLOCK_BRANCH_PREFIX = "block_"
|
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""

    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    # Build an empty block, then fetch its weights (from cache or HF Hub).
    block = WrappedBloomBlock(config)
    state_dict = _load_state_dict(
        converted_model_name_or_path,
        block_index,
        config,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    if torch_dtype == "auto":
        # "auto": cast each parameter to the dtype stored in the checkpoint.
        with torch.no_grad():
            for name, param in block.named_parameters():
                assert name in state_dict, f"{name} not in state dict"
                param.data = param.data.to(state_dict[name].dtype)
    else:
        # Explicit dtype: cast the whole block uniformly.
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        block = block.to(dtype=torch_dtype)

    report = block.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block
|
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(
    pretrained_model_name_or_path: str,
    block_index: int,
    config: BloomConfig,
    *,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    min_backoff: float = 5,
) -> OrderedDict[str, torch.Tensor]:
    """Fetch a single block's state dict, preferring the local cache, then HF Hub with retries.

    Retries the download forever with exponential backoff (min_backoff * 2**attempt seconds).
    """
    # Each block lives on its own branch of the converted repo, e.g. "block_7".
    revision = BLOCK_BRANCH_PREFIX + str(block_index)

    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            archive_file = get_file_from_repo(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if archive_file is not None:
                return torch.load(archive_file, map_location="cpu")
    except Exception:
        # Cache miss or corruption is non-fatal: fall through to a fresh download.
        logger.debug(
            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
        )

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    for attempt_no in itertools.count():
        try:
            with allow_cache_writes(cache_dir):
                block_size = get_block_size(config, "disk")
                free_disk_space_for(
                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
                )

                archive_file = get_file_from_repo(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                return torch.load(archive_file, map_location="cpu")
        except Exception as e:
            # Back off exponentially and retry indefinitely (servers must eventually come up).
            delay = min_backoff * (2**attempt_no)
            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)


# Maps CLI dtype names to torch dtypes; "auto" defers to the checkpoint's stored dtypes.
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
PyTorch BLOOM model that implements several memory-efficient modes.
|
|
||||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
||||||
See commit history for authorship.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from hivemind import get_logger
|
|
||||||
from torch import nn
|
|
||||||
from transformers import BloomConfig
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
|
|
||||||
class LMHead(nn.Module):
    """
    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
    """

    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
        super().__init__()
        # Tied weights: logits are computed directly against the input embedding matrix.
        self.word_embeddings = word_embeddings
        # Chunk size for the CPU half-precision path in chunked_forward.
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    # NOTE(review): in_features returns num_embeddings (vocab size) and out_features returns
    # embedding_dim (hidden size); for a hidden->vocab projection these look swapped — confirm
    # against callers before changing.
    @property
    def in_features(self) -> int:
        return self.word_embeddings.num_embeddings

    @property
    def out_features(self) -> int:
        return self.word_embeddings.embedding_dim

    @property
    def weight(self):
        return self.word_embeddings.weight

    @property
    def bias(self):
        # No bias: logits are a pure matmul with the tied embedding matrix.
        return None

    def forward(self, hidden_states):
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunk_size: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        # Everything is computed in fp32; output covers the full vocabulary.
        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)

        # Cast one vocab slice at a time to fp32 and fill the matching logit columns.
        for i in range(0, num_embeddings, self.chunk_size):
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"apply_residual_connection_post_layernorm": false,
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"attention_softmax_in_fp32": true,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"hidden_dropout": 0.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"layer_norm_epsilon": 1e-05,
|
|
||||||
"masked_softmax_fusion": true,
|
|
||||||
"model_type": "bloom",
|
|
||||||
"n_embed": 14336,
|
|
||||||
"n_layer": 70,
|
|
||||||
"num_attention_heads": 112,
|
|
||||||
"pretraining_tp": 4,
|
|
||||||
"slow_but_exact": false,
|
|
||||||
"transformers_version": "4.20.0.dev0",
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 250880
|
|
||||||
}
|
|
@ -1,92 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch.backends.quantized
|
|
||||||
import torch.nn as nn
|
|
||||||
import transformers
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from huggingface_hub import Repository
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
||||||
|
|
||||||
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
||||||
from petals.client import DistributedBloomConfig
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
||||||
|
|
||||||
|
|
||||||
# Script body: convert a BLOOM checkpoint into the per-block branch layout used by Petals servers,
# then push the result to an HF Hub repo (one branch per transformer block + a client branch).
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    # Loading bloom-176b materializes the whole model in RAM; warn if the host is clearly too small.
    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    # Refuse to write into a non-empty (or non-directory) output path.
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    # The DHT prefix identifies this converted model on the Petals network.
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    # Clone (or reuse) the output repo locally; all commits below go through this working copy.
    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    # One branch per block, each containing only that block's state dict.
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    # The client branch holds everything except the transformer blocks (embeddings, config, tokenizer).
    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env bash
# Launch a Petals server for a subset of BLOOM blocks in 8-bit mode, creating (or reusing)
# a dedicated conda environment first. Server logs go to <identity_path>.log.

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
  echo " -m: model name"
  echo " -i: initial peer"
  echo " -d: device" >&2
  echo " -p: server identity path" >&2
  echo " -b: block_ids" >&2
  echo " -a: host maddrs" >&2
  echo " -t: whether to run local tests" >&2
  exit 1
}

# All options below are required pairs, so fewer than 8 argv entries means misuse.
if [ ! $# -ge 8 ]; then
  instructions
fi

while getopts ":m:i:d:p:b:a:t:" option; do
  case $option in
    m) MODEL_NAME=${OPTARG}
      ;;
    i) INITIAL_PEER=${OPTARG}
      ;;
    d) DEVICE=${OPTARG}
      ;;
    p) SERVER_ID_PATH=${OPTARG}
      ;;
    b) BLOCK_IDS=${OPTARG}
      ;;
    a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
      ;;
    t) RUN_LOCAL_TESTS=true
      ;;
    \?) instructions
      ;;
  esac
done


echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"


###########################
# Install or activate env #
###########################

# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
# Reuse the env if it exists; otherwise create it and install pinned CUDA/torch deps.
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple -r .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi

##############
# Run server #
##############

# Runs in the background; stdout+stderr are captured in ${SERVER_ID_PATH}.log.
python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
|
|
@ -1,51 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from tqdm.auto import trange
|
|
||||||
from transformers import BloomConfig
|
|
||||||
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
|
|
||||||
|
|
||||||
from petals.bloom.block import BloomBlock
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
logger.warning("inference_one_block will soon be deprecated in favour of tests!")
|
|
||||||
|
|
||||||
|
|
||||||
def print_device_info(device=None):
|
|
||||||
"""Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
|
|
||||||
device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
|
||||||
logger.info(f"Using device: {device}")
|
|
||||||
|
|
||||||
# Additional Info when using cuda
|
|
||||||
if device.type == "cuda":
|
|
||||||
logger.info(torch.cuda.get_device_name(0))
|
|
||||||
logger.info(f"Memory Usage:")
|
|
||||||
logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
|
|
||||||
logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
|
|
||||||
parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
|
|
||||||
parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
|
|
||||||
parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
|
|
||||||
parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.device is None:
|
|
||||||
args.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
config = BloomConfig.from_json_file(args.config)
|
|
||||||
block = BloomBlock(config).to(args.device)
|
|
||||||
|
|
||||||
cache = None
|
|
||||||
|
|
||||||
for i in trange(args.num_steps):
|
|
||||||
dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
|
|
||||||
alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
|
|
||||||
|
|
||||||
print_device_info(args.device)
|
|
@ -1,5 +0,0 @@
|
|||||||
device=cpu
|
|
||||||
block_ids=2:3
|
|
||||||
id_path=./server.id
|
|
||||||
maddr=/ip4/127.0.0.1/tcp/30000
|
|
||||||
#
|
|
@ -1,6 +0,0 @@
|
|||||||
name=bloom-peer-0.bloom.net
|
|
||||||
device=cpu
|
|
||||||
block_ids=1:3
|
|
||||||
id_path=./server.id
|
|
||||||
maddr=/ip4/0.0.0.0/tcp/30000
|
|
||||||
#
|
|
@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
|
||||||
|
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
|
||||||
|
|
||||||
|
This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.
|
||||||
|
|
||||||
|
This may be eventually merged to the hivemind upstream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from secrets import token_hex
|
||||||
|
|
||||||
|
from hivemind.dht import DHT, DHTNode
|
||||||
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||||
|
from hivemind.utils.networking import log_visible_maddrs
|
||||||
|
|
||||||
|
from petals.server.reachability import ReachabilityProtocol
|
||||||
|
|
||||||
|
use_hivemind_log_handler("in_root_logger")
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def report_status(dht: DHT, node: DHTNode):
|
||||||
|
logger.info(
|
||||||
|
f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
|
||||||
|
f"are in the local routing table "
|
||||||
|
)
|
||||||
|
logger.debug(f"Routing table contents: {node.protocol.routing_table}")
|
||||||
|
logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
|
||||||
|
logger.debug(f"Local storage contents: {node.protocol.storage}")
|
||||||
|
|
||||||
|
# Contact peers and keep the routing table healthy (remove stale PeerIDs)
|
||||||
|
await node.get(f"heartbeat_{token_hex(16)}", latest=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial_peers",
|
||||||
|
nargs="*",
|
||||||
|
help="Multiaddrs of the peers that will welcome you into the existing DHT. "
|
||||||
|
"Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host_maddrs",
|
||||||
|
nargs="*",
|
||||||
|
default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
|
||||||
|
help="Multiaddrs to listen for external connections from other DHT instances. "
|
||||||
|
"Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--announce_maddrs",
|
||||||
|
nargs="*",
|
||||||
|
help="Visible multiaddrs the host announces for external connections from other DHT instances",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_ipfs",
|
||||||
|
action="store_true",
|
||||||
|
help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
|
||||||
|
"part of the multiaddrs for the initial_peers "
|
||||||
|
"(no need to specify a particular IPv4/IPv6 host and port)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--identity_path",
|
||||||
|
help="Path to a private key file. If defined, makes the peer ID deterministic. "
|
||||||
|
"If the file does not exist, writes a new private key to this file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_relay",
|
||||||
|
action="store_false",
|
||||||
|
dest="use_relay",
|
||||||
|
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_auto_relay",
|
||||||
|
action="store_true",
|
||||||
|
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dht = DHT(
|
||||||
|
start=True,
|
||||||
|
initial_peers=args.initial_peers,
|
||||||
|
host_maddrs=args.host_maddrs,
|
||||||
|
announce_maddrs=args.announce_maddrs,
|
||||||
|
use_ipfs=args.use_ipfs,
|
||||||
|
identity_path=args.identity_path,
|
||||||
|
use_relay=args.use_relay,
|
||||||
|
use_auto_relay=args.use_auto_relay,
|
||||||
|
)
|
||||||
|
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
|
||||||
|
|
||||||
|
reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
dht.run_coroutine(report_status, return_future=False)
|
||||||
|
time.sleep(args.refresh_period)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1,109 +0,0 @@
|
|||||||
# !/usr/bin/env bash
|
|
||||||
|
|
||||||
#################
|
|
||||||
# Parse options #
|
|
||||||
#################
|
|
||||||
|
|
||||||
instructions() {
|
|
||||||
echo "Usage: $0 [-n] [-c]" >&2
|
|
||||||
echo " -n: number of servers to run" >&2
|
|
||||||
echo " -c: path to the server configs" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ $# != 4 ]; then
|
|
||||||
instructions
|
|
||||||
fi
|
|
||||||
|
|
||||||
while getopts ":n:c:t:" option; do
|
|
||||||
case $option in
|
|
||||||
n) NUM_SERVERS=${OPTARG}
|
|
||||||
;;
|
|
||||||
c) CONFIG_PATH=${OPTARG}
|
|
||||||
;;
|
|
||||||
\?) instructions
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
|
|
||||||
###########################
|
|
||||||
# Install or activate env #
|
|
||||||
###########################
|
|
||||||
|
|
||||||
source ~/miniconda3/etc/profile.d/conda.sh
|
|
||||||
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
|
|
||||||
conda activate bloom-demo
|
|
||||||
else
|
|
||||||
conda create -y --name bloom-demo python=3.8.12 pip
|
|
||||||
conda activate bloom-demo
|
|
||||||
|
|
||||||
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
|
||||||
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
pip install -i https://pypi.org/simple -r .
|
|
||||||
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
#######################
|
|
||||||
# Create Initial peer #
|
|
||||||
#######################
|
|
||||||
|
|
||||||
hivemind-dht &> tmp.out &
|
|
||||||
sleep 5
|
|
||||||
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
|
|
||||||
echo "Initial peer: ${INITIAL_PEER}"
|
|
||||||
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# Initialize the config file #
|
|
||||||
##############################
|
|
||||||
|
|
||||||
typeset -A cfg
|
|
||||||
cfg=( # set default values in config array
|
|
||||||
[device]="cpu"
|
|
||||||
[block_ids]="1:2"
|
|
||||||
[id_path]="server.id"
|
|
||||||
[maddr]="/ip4/127.0.0.1/tcp/30000"
|
|
||||||
)
|
|
||||||
|
|
||||||
###############
|
|
||||||
# Run servers #
|
|
||||||
###############
|
|
||||||
|
|
||||||
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
|
|
||||||
do
|
|
||||||
###############
|
|
||||||
# Read config #
|
|
||||||
###############
|
|
||||||
|
|
||||||
while read line
|
|
||||||
do
|
|
||||||
if echo $line | grep -F = &>/dev/null
|
|
||||||
then
|
|
||||||
varname=$(echo "$line" | cut -d '=' -f 1)
|
|
||||||
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
|
|
||||||
fi
|
|
||||||
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
|
|
||||||
|
|
||||||
echo "=== Server #${SERVER_ID} ==="
|
|
||||||
echo "Server ID: ${cfg[id_path]}"
|
|
||||||
echo "Device: ${cfg[device]}"
|
|
||||||
echo "Bloom block ids: ${cfg[block_ids]}"
|
|
||||||
echo "Host maddr: ${cfg[maddr]}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Run server #
|
|
||||||
##############
|
|
||||||
|
|
||||||
tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
|
|
||||||
done
|
|
||||||
|
|
||||||
#####################
|
|
||||||
# Kill initial peer #
|
|
||||||
#####################
|
|
||||||
|
|
||||||
sleep 10
|
|
||||||
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
|
|
||||||
rm tmp.out
|
|
@ -1,110 +0,0 @@
|
|||||||
# !/usr/bin/env bash
|
|
||||||
|
|
||||||
SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
|
|
||||||
|
|
||||||
#################
|
|
||||||
# Parse options #
|
|
||||||
#################
|
|
||||||
|
|
||||||
instructions() {
|
|
||||||
echo "Usage: $0 [-u] [-n] [-c]" >&2
|
|
||||||
echo " -u: username" >&2
|
|
||||||
echo " -n: number of servers to run" >&2
|
|
||||||
echo " -c: path to the server configs" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ $# != 6 ]; then
|
|
||||||
instructions
|
|
||||||
fi
|
|
||||||
|
|
||||||
while getopts ":u:n:c:" option; do
|
|
||||||
case $option in
|
|
||||||
u) USERNAME=${OPTARG}
|
|
||||||
;;
|
|
||||||
n) NUM_SERVERS=${OPTARG}
|
|
||||||
;;
|
|
||||||
c) CONFIG_PATH=${OPTARG}
|
|
||||||
;;
|
|
||||||
\?) instructions
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
|
|
||||||
###########################
|
|
||||||
# Install or activate env #
|
|
||||||
###########################
|
|
||||||
|
|
||||||
source ~/miniconda3/etc/profile.d/conda.sh
|
|
||||||
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
|
|
||||||
conda activate bloom-demo
|
|
||||||
else
|
|
||||||
conda create -y --name bloom-demo python=3.8.12 pip
|
|
||||||
conda activate bloom-demo
|
|
||||||
|
|
||||||
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
|
||||||
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
pip install -i https://pypi.org/simple -r .
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
#######################
|
|
||||||
# Create Initial peer #
|
|
||||||
#######################
|
|
||||||
|
|
||||||
hivemind-dht &> tmp.out &
|
|
||||||
|
|
||||||
sleep 5
|
|
||||||
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
|
|
||||||
rm tmp.out
|
|
||||||
echo "Initial peer: ${INITIAL_PEER}"
|
|
||||||
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# Initialize the config file #
|
|
||||||
##############################
|
|
||||||
|
|
||||||
typeset -A cfg
|
|
||||||
cfg=( # set default values in config array
|
|
||||||
[name]=""
|
|
||||||
[device]="cpu"
|
|
||||||
[block_ids]="1:2"
|
|
||||||
[id_path]="server.id"
|
|
||||||
[maddr]="/ip4/0.0.0.0/tcp/30000"
|
|
||||||
)
|
|
||||||
|
|
||||||
###############
|
|
||||||
# Run servers #
|
|
||||||
###############
|
|
||||||
|
|
||||||
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
|
|
||||||
do
|
|
||||||
###############
|
|
||||||
# Read config #
|
|
||||||
###############
|
|
||||||
|
|
||||||
while read line
|
|
||||||
do
|
|
||||||
if echo $line | grep -F = &>/dev/null
|
|
||||||
then
|
|
||||||
varname=$(echo "$line" | cut -d '=' -f 1)
|
|
||||||
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
|
|
||||||
fi
|
|
||||||
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
|
|
||||||
|
|
||||||
SERVER_NAME="${USERNAME}@${cfg[name]}"
|
|
||||||
echo "=== Server #${SERVER_ID} ==="
|
|
||||||
echo "Server name ${SERVER_NAME}"
|
|
||||||
echo "Server ID: ${cfg[id_path]}"
|
|
||||||
echo "Device: ${cfg[device]}"
|
|
||||||
echo "Bloom block ids: ${cfg[block_ids]}"
|
|
||||||
echo "Host maddr: ${cfg[maddr]}"
|
|
||||||
echo "================="
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Run server #
|
|
||||||
##############
|
|
||||||
|
|
||||||
ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
|
|
||||||
done
|
|
@ -1,10 +1,4 @@
|
|||||||
from petals.client.inference_session import InferenceSession
|
from petals.client.inference_session import InferenceSession
|
||||||
from petals.client.remote_model import (
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
DistributedBloomConfig,
|
|
||||||
DistributedBloomForCausalLM,
|
|
||||||
DistributedBloomForSequenceClassification,
|
|
||||||
DistributedBloomModel,
|
|
||||||
)
|
|
||||||
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
|
|
||||||
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
||||||
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||||
|
@ -0,0 +1,94 @@
|
|||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers import BloomPreTrainedModel, modeling_utils
|
||||||
|
|
||||||
|
from petals.utils.version import get_compatible_model_repo
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FromPretrainedMixin:
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
model_name_or_path: Union[str, os.PathLike, None],
|
||||||
|
*args,
|
||||||
|
low_cpu_mem_usage: Optional[bool] = None,
|
||||||
|
torch_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model_name_or_path = get_compatible_model_repo(model_name_or_path)
|
||||||
|
if low_cpu_mem_usage is None:
|
||||||
|
low_cpu_mem_usage = True
|
||||||
|
if torch_dtype is None:
|
||||||
|
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
|
||||||
|
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
|
||||||
|
torch_dtype = "auto"
|
||||||
|
|
||||||
|
with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
|
||||||
|
return super().from_pretrained(
|
||||||
|
model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
||||||
|
"low_cpu_mem_usage(`bool`, *optional*)",
|
||||||
|
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
|
||||||
|
).replace(
|
||||||
|
"torch_dtype (`str` or `torch.dtype`, *optional*)",
|
||||||
|
'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_shard_config = threading.local()
|
||||||
|
_shard_config.ignored_keys = None
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def ignore_keys(patterns: List[str]):
|
||||||
|
try:
|
||||||
|
prev_patterns = _shard_config.ignored_keys
|
||||||
|
_shard_config.ignored_keys = patterns
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_shard_config.ignored_keys = prev_patterns
|
||||||
|
|
||||||
|
|
||||||
|
def patched_get_checkpoint_shard_files(
|
||||||
|
pretrained_model_name_or_path, index_filename, *args, **kwargs
|
||||||
|
) -> Tuple[List[str], dict]:
|
||||||
|
"""Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
|
||||||
|
|
||||||
|
should_ignore_keys = _shard_config.ignored_keys is not None
|
||||||
|
tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
|
||||||
|
with tempdir_ctx as tempdir:
|
||||||
|
if should_ignore_keys:
|
||||||
|
with open(index_filename) as f:
|
||||||
|
index = json.load(f)
|
||||||
|
n_original_shards = len(set(index["weight_map"].values()))
|
||||||
|
|
||||||
|
index["weight_map"] = {
|
||||||
|
param_name: filename
|
||||||
|
for param_name, filename in index["weight_map"].items()
|
||||||
|
if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
|
||||||
|
}
|
||||||
|
n_loaded_shards = len(set(index["weight_map"].values()))
|
||||||
|
logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")
|
||||||
|
|
||||||
|
# Replace the original index with a patched JSON, where ignored keys are removed
|
||||||
|
index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
|
||||||
|
with open(index_filename, "w") as f:
|
||||||
|
json.dump(index, f)
|
||||||
|
|
||||||
|
return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
|
||||||
|
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
|
@ -0,0 +1,84 @@
|
|||||||
|
import dataclasses
|
||||||
|
import platform
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from hivemind import get_logger
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LMHeadConfig:
|
||||||
|
# This settings matter for running the client with dtype bfloat16 on CPU.
|
||||||
|
# If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
|
||||||
|
use_chunked_forward: Union[str, bool] = "auto"
|
||||||
|
chunked_forward_step: int = 16384
|
||||||
|
|
||||||
|
|
||||||
|
class LMHead(nn.Module):
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not config.tie_word_embeddings:
|
||||||
|
self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
|
||||||
|
self.weight.requires_grad = False
|
||||||
|
else:
|
||||||
|
self.weight = None # Will be set to get_input_embeddings().weight during loading the model
|
||||||
|
self.bias = None
|
||||||
|
self.in_features = config.hidden_size # Similar to nn.Linear attributes
|
||||||
|
self.out_features = config.vocab_size
|
||||||
|
|
||||||
|
self.use_chunked_forward = config.use_chunked_forward
|
||||||
|
if self.use_chunked_forward == "auto":
|
||||||
|
if platform.machine() == "x86_64":
|
||||||
|
# Import of cpufeature may crash on non-x86_64 machines
|
||||||
|
from cpufeature import CPUFeature
|
||||||
|
|
||||||
|
# If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
|
||||||
|
# Otherwise, it's ~8x slower.
|
||||||
|
self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
|
||||||
|
else:
|
||||||
|
self.use_chunked_forward = True
|
||||||
|
self.chunked_forward_step = config.chunked_forward_step
|
||||||
|
self._bf16_warning_shown = False
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if (
|
||||||
|
self.weight.dtype in [torch.float16, torch.bfloat16]
|
||||||
|
and self.weight.device.type == "cpu"
|
||||||
|
and self.use_chunked_forward
|
||||||
|
):
|
||||||
|
lm_logits = self.chunked_forward(hidden_states)
|
||||||
|
else:
|
||||||
|
# Switch dtype in case word_embeddings are fp16/bf16
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
lm_logits = F.linear(hidden_states, self.weight)
|
||||||
|
return lm_logits
|
||||||
|
|
||||||
|
def chunked_forward(self, hidden_states):
|
||||||
|
"""Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
|
||||||
|
chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
|
||||||
|
"""
|
||||||
|
assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
|
||||||
|
|
||||||
|
if not self._bf16_warning_shown:
|
||||||
|
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
|
||||||
|
logger.warning(
|
||||||
|
"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
|
||||||
|
"Consider loading the model with torch_dtype='float32'"
|
||||||
|
)
|
||||||
|
self._bf16_warning_shown = True
|
||||||
|
|
||||||
|
hidden_states = hidden_states.float()
|
||||||
|
output = torch.empty(*hidden_states.shape[:-1], self.out_features)
|
||||||
|
|
||||||
|
for i in range(0, self.out_features, self.chunked_forward_step):
|
||||||
|
chunk = self.weight[i : i + self.chunked_forward_step].float()
|
||||||
|
output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
|
||||||
|
return output
|
@ -0,0 +1,84 @@
|
|||||||
|
import dataclasses
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from petals.utils.misc import DUMMY
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PTuneConfig:
|
||||||
|
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
||||||
|
tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
|
||||||
|
|
||||||
|
|
||||||
|
class PTuneMixin:
|
||||||
|
_keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]
|
||||||
|
|
||||||
|
def init_prompts(self, config: PretrainedConfig) -> None:
|
||||||
|
if config.tuning_mode and "ptune" in config.tuning_mode:
|
||||||
|
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
||||||
|
self.pre_seq_len = config.pre_seq_len
|
||||||
|
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
||||||
|
|
||||||
|
with force_non_empty_weights():
|
||||||
|
# Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
|
||||||
|
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
||||||
|
if config.tuning_mode == "deep_ptune":
|
||||||
|
self.intermediate_prompt_embeddings = nn.Embedding(
|
||||||
|
self.pre_seq_len,
|
||||||
|
config.num_hidden_layers * config.hidden_size,
|
||||||
|
# ^-- TODO: should be num_hidden_layers - 1
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
elif config.tuning_mode:
|
||||||
|
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
||||||
|
|
||||||
|
def get_prompt(self, batch_size):
|
||||||
|
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
||||||
|
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
||||||
|
prompts = self.prompt_embeddings(prefix_tokens)
|
||||||
|
|
||||||
|
if self.config.tuning_mode == "deep_ptune":
|
||||||
|
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
||||||
|
intermediate_prompts = intermediate_prompts.view(
|
||||||
|
batch_size,
|
||||||
|
self.pre_seq_len,
|
||||||
|
self.config.num_hidden_layers,
|
||||||
|
self.config.hidden_size
|
||||||
|
# TODO: should be num_hidden_layers - 1
|
||||||
|
)
|
||||||
|
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
||||||
|
else:
|
||||||
|
intermediate_prompts = DUMMY
|
||||||
|
|
||||||
|
dtype = self.word_embeddings.weight.dtype
|
||||||
|
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
_original_register_parameter = nn.Module.register_parameter
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def force_non_empty_weights():
|
||||||
|
"""
|
||||||
|
This context manager allows to bypass the accelerate.init_empty_weights() context manager
|
||||||
|
(that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
|
||||||
|
The transformers library should replace all meta tensors by empty tensors by itself
|
||||||
|
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
||||||
|
|
||||||
|
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
possibly_patched_register_parameter = nn.Module.register_parameter
|
||||||
|
nn.Module.register_parameter = _original_register_parameter
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
nn.Module.register_parameter = possibly_patched_register_parameter
|
@ -1,264 +0,0 @@
|
|||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import hivemind
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
||||||
from transformers.models.bloom import (
|
|
||||||
BloomConfig,
|
|
||||||
BloomForCausalLM,
|
|
||||||
BloomForSequenceClassification,
|
|
||||||
BloomModel,
|
|
||||||
BloomPreTrainedModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
from petals.bloom.modeling_utils import LMHead
|
|
||||||
from petals.client.remote_generation import RemoteGenerationMixin
|
|
||||||
from petals.client.remote_sequential import RemoteSequential
|
|
||||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
|
||||||
from petals.utils.misc import DUMMY
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomConfig(BloomConfig):
|
|
||||||
"""
|
|
||||||
A bloom config that contains information about DHT peers.
|
|
||||||
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
|
|
||||||
"""
|
|
||||||
|
|
||||||
initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT
|
|
||||||
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
|
||||||
daemon_startup_timeout: int = 30
|
|
||||||
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
|
|
||||||
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
|
|
||||||
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
|
||||||
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
|
|
||||||
request_timeout: int = 30 # a number of seconds for waiting result from each node
|
|
||||||
|
|
||||||
|
|
||||||
original_register_parameter = nn.Module.register_parameter
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def force_non_empty_weights():
|
|
||||||
"""
|
|
||||||
This context manager allows to bypass the accelerate.init_empty_weights() context manager
|
|
||||||
(that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
|
|
||||||
The transformers library should replace all meta tensors by empty tensors by itself
|
|
||||||
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
|
||||||
|
|
||||||
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
possibly_patched_register_parameter = nn.Module.register_parameter
|
|
||||||
nn.Module.register_parameter = original_register_parameter
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
nn.Module.register_parameter = possibly_patched_register_parameter
|
|
||||||
|
|
||||||
|
|
||||||
class _LowCPUMemoryMixin:
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
|
|
||||||
if low_cpu_mem_usage is None:
|
|
||||||
low_cpu_mem_usage = True
|
|
||||||
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
|
|
||||||
|
|
||||||
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
|
||||||
"low_cpu_mem_usage(`bool`, *optional*)",
|
|
||||||
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
|
|
||||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
|
||||||
r"^(intermediate_)?prompt_embeddings\.weight$",
|
|
||||||
]
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
|
||||||
assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
|
|
||||||
|
|
||||||
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
|
||||||
super().__init__(config)
|
|
||||||
assert len(self.h) == 0
|
|
||||||
config.n_layer = n_layer
|
|
||||||
|
|
||||||
dht = (
|
|
||||||
config.dht
|
|
||||||
if config.dht is not None
|
|
||||||
else hivemind.DHT(
|
|
||||||
initial_peers=config.initial_peers,
|
|
||||||
client_mode=True,
|
|
||||||
num_workers=n_layer,
|
|
||||||
startup_timeout=config.daemon_startup_timeout,
|
|
||||||
start=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
|
|
||||||
self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
|
|
||||||
|
|
||||||
# Forbid accumulate grads for embeddings and layernorm
|
|
||||||
self.set_requires_grad(False)
|
|
||||||
|
|
||||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
|
||||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
||||||
self.pre_seq_len = config.pre_seq_len
|
|
||||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
||||||
|
|
||||||
with force_non_empty_weights():
|
|
||||||
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
|
||||||
logger.info(
|
|
||||||
"Prompt embeddings and their optimizer statistics will be kept in float32 "
|
|
||||||
"to increase ptune quality"
|
|
||||||
)
|
|
||||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
|
||||||
if config.tuning_mode == "deep_ptune":
|
|
||||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
|
||||||
self.pre_seq_len,
|
|
||||||
config.num_hidden_layers * config.hidden_size,
|
|
||||||
# ^-- TODO: should be num_hidden_layers - 1
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
elif config.tuning_mode:
|
|
||||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
||||||
|
|
||||||
def set_requires_grad(self, value):
|
|
||||||
for p in self.parameters():
|
|
||||||
p.requires_grad = value
|
|
||||||
|
|
||||||
def get_prompt(self, batch_size):
|
|
||||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
|
||||||
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
|
||||||
prompts = self.prompt_embeddings(prefix_tokens)
|
|
||||||
|
|
||||||
if self.config.tuning_mode == "deep_ptune":
|
|
||||||
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
|
||||||
intermediate_prompts = intermediate_prompts.view(
|
|
||||||
batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
|
|
||||||
)
|
|
||||||
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
|
||||||
else:
|
|
||||||
intermediate_prompts = DUMMY
|
|
||||||
|
|
||||||
dtype = self.word_embeddings.weight.dtype
|
|
||||||
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
|
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if not (v is None or v is False):
|
|
||||||
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
||||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
||||||
|
|
||||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
|
||||||
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
|
||||||
else:
|
|
||||||
hidden_states = self.h(hidden_states)
|
|
||||||
|
|
||||||
# Remove prefix
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
|
||||||
|
|
||||||
# Add last hidden state
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
|
||||||
hidden_states = hidden_states.view(output_shape)
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=None,
|
|
||||||
hidden_states=None,
|
|
||||||
attentions=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
|
|
||||||
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = (
|
|
||||||
BloomForCausalLM._keys_to_ignore_on_load_missing
|
|
||||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
||||||
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
BloomPreTrainedModel.__init__(self, config)
|
|
||||||
self.transformer = DistributedBloomModel(config)
|
|
||||||
self.lm_head = LMHead(config, self.transformer.word_embeddings)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.transformer.word_embeddings
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
|
||||||
if self.config.tie_word_embeddings:
|
|
||||||
return None
|
|
||||||
return self.lm_head
|
|
||||||
|
|
||||||
def set_input_embeddings(self, new_embeddings: nn.Embedding):
|
|
||||||
assert isinstance(new_embeddings, nn.Embedding)
|
|
||||||
self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
|
|
||||||
assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
|
|
||||||
|
|
||||||
def set_output_embeddings(self, new_lm_head: nn.Linear):
|
|
||||||
with torch.no_grad():
|
|
||||||
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
|
||||||
self.lm_head.bias[...] = new_lm_head.bias
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
|
|
||||||
_keys_to_ignore_on_load_missing = (
|
|
||||||
BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
|
||||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
||||||
)
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
BloomPreTrainedModel.__init__(self, config)
|
|
||||||
self.num_labels = config.num_labels
|
|
||||||
|
|
||||||
self.transformer = DistributedBloomModel(config)
|
|
||||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
@ -1,6 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
PUBLIC_INITIAL_PEERS = [
|
PUBLIC_INITIAL_PEERS = [
|
||||||
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
# IPv4 DNS addresses
|
||||||
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
# IPv6 DNS addresses
|
||||||
|
"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
|
"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
|
# Reserved IPs
|
||||||
|
"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
|
"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# The reachability API is currently used only when connecting to the public swarm
|
||||||
|
REACHABILITY_API_URL = "https://health.petals.dev"
|
||||||
|
|
||||||
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
from petals.models.bloom import *
|
||||||
|
from petals.models.llama import *
|
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
from petals.models.bloom.model import (
|
||||||
|
DistributedBloomForCausalLM,
|
||||||
|
DistributedBloomForSequenceClassification,
|
||||||
|
DistributedBloomModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedBloomConfig,
|
||||||
|
model=DistributedBloomModel,
|
||||||
|
model_for_causal_lm=DistributedBloomForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedBloomForSequenceClassification,
|
||||||
|
)
|
@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
Bloom intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedBloomBlock(BloomBlock):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
alibi: Optional[torch.Tensor] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
assert attention_mask is None, "Non-causal attention masks are not supported yet"
|
||||||
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
|
if alibi is None:
|
||||||
|
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
|
||||||
|
)
|
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.bloom import BloomConfig
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||||
|
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
|
||||||
|
block_class = WrappedBloomBlock
|
||||||
|
attn_class = BloomAttention
|
||||||
|
block_prefix = "h"
|
||||||
|
|
||||||
|
num_key_value_groups = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
|
||||||
|
|
||||||
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||||
|
if loading_from_repo and dht_prefix is None:
|
||||||
|
# We need "-petals" for backward compatibility with Petals < 1.2.0
|
||||||
|
dht_prefix = str(model_name_or_path) + "-petals"
|
||||||
|
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||||
|
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
@ -0,0 +1,126 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
||||||
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"^h\."]
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||||
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||||
|
super().__init__(config)
|
||||||
|
assert len(self.h) == 0
|
||||||
|
config.num_hidden_layers = n_layer
|
||||||
|
|
||||||
|
self.h = RemoteSequential(config, dht=dht)
|
||||||
|
|
||||||
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||||
|
self.init_prompts(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if not (v is None or v is False):
|
||||||
|
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||||
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||||
|
|
||||||
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
||||||
|
else:
|
||||||
|
hidden_states = self.h(hidden_states)
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=None,
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig):
|
||||||
|
BloomPreTrainedModel.__init__(self, config)
|
||||||
|
self.transformer = DistributedBloomModel(config)
|
||||||
|
self.lm_head = LMHead(config)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig):
|
||||||
|
BloomPreTrainedModel.__init__(self, config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.transformer = DistributedBloomModel(config)
|
||||||
|
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.llama.block import WrappedLlamaBlock
|
||||||
|
from petals.models.llama.config import DistributedLlamaConfig
|
||||||
|
from petals.models.llama.model import (
|
||||||
|
DistributedLlamaForCausalLM,
|
||||||
|
DistributedLlamaForSequenceClassification,
|
||||||
|
DistributedLlamaModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedLlamaConfig,
|
||||||
|
model=DistributedLlamaModel,
|
||||||
|
model_for_causal_lm=DistributedLlamaForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
|
||||||
|
)
|
@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
LLaMA intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedLlamaBlock(LlamaDecoderLayer):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
past_key_value = layer_past
|
||||||
|
if past_key_value is not None:
|
||||||
|
past_key_values_length = past_key_value[0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = hidden_states.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||||
|
)
|
||||||
|
attention_mask = LlamaModel._prepare_decoder_attention_mask(
|
||||||
|
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = super().forward(
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
present_key_value = outputs[-1]
|
||||||
|
present_key_value = self._reorder_cache_from_llama_to_bloom(
|
||||||
|
present_key_value, batch_size, seq_length_with_past
|
||||||
|
)
|
||||||
|
outputs = outputs[:-1] + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _reorder_cache_from_bloom_to_llama(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
value_states = value_states.view(*key_states.shape)
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _reorder_cache_from_llama_to_bloom(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
key_states = key_states.view(*value_states.shape)
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
return (key_states, value_states)
|
@ -0,0 +1,45 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||||
|
from petals.models.llama.block import WrappedLlamaBlock
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
|
||||||
|
block_class = WrappedLlamaBlock
|
||||||
|
attn_class = LlamaAttention
|
||||||
|
block_prefix = "model.layers"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_groups(self):
|
||||||
|
return self.num_attention_heads // self.num_key_value_heads
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Make sure you follow the LLaMA's terms of use: "
|
||||||
|
"https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||||
|
if loading_from_repo and dht_prefix is None:
|
||||||
|
dht_prefix = str(model_name_or_path)
|
||||||
|
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
|
||||||
|
if not dht_prefix.endswith("-hf"):
|
||||||
|
dht_prefix += "-hf"
|
||||||
|
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||||
|
|
||||||
|
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||||
|
config = result[0] if isinstance(result, tuple) else result
|
||||||
|
config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization
|
||||||
|
return result
|
@ -0,0 +1,151 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.llama.config import DistributedLlamaConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
|
||||||
|
"""LlamaModel, but all transformer layers are hosted by the swarm"""
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
|
||||||
|
|
||||||
|
config_class = DistributedLlamaConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||||
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||||
|
super().__init__(config)
|
||||||
|
assert len(self.layers) == 0
|
||||||
|
config.num_hidden_layers = n_layer
|
||||||
|
|
||||||
|
self.layers = RemoteSequential(config, dht=dht)
|
||||||
|
|
||||||
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||||
|
self.init_prompts(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseModelOutputWithPast:
|
||||||
|
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if not (v is None or v is False):
|
||||||
|
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||||
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
|
||||||
|
else:
|
||||||
|
hidden_states = self.layers(hidden_states)
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=None,
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||||
|
return nn.Identity()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.norm
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    """LlamaForCausalLM whose transformer body is a DistributedLlamaModel (blocks served remotely)."""

    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        # Call LlamaPreTrainedModel.__init__ directly: LlamaForCausalLM.__init__ would
        # build a local LlamaModel, while we substitute the distributed one below.
        LlamaPreTrainedModel.__init__(self, config)
        self.model = DistributedLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    """LlamaForSequenceClassification whose transformer body is a DistributedLlamaModel."""

    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        # Call LlamaPreTrainedModel.__init__ directly to avoid building a local LlamaModel.
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
|
@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import AsyncIterator, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
|
||||||
|
from hivemind.moe.expert_uid import ExpertUID
|
||||||
|
from hivemind.proto import runtime_pb2
|
||||||
|
from hivemind.utils.nested import nested_flatten
|
||||||
|
|
||||||
|
from petals.data_structures import InferenceMetadata
|
||||||
|
from petals.server.backend import TransformerBackend
|
||||||
|
from petals.server.memory_cache import Handle
|
||||||
|
from petals.server.task_pool import PrioritizedTaskPool
|
||||||
|
from petals.server.task_prioritizer import TaskPrioritizerBase
|
||||||
|
from petals.utils.convert_block import QuantType
|
||||||
|
from petals.utils.misc import DUMMY, is_dummy
|
||||||
|
|
||||||
|
# We prioritize short inference requests and make them use a *merged* inference pool,
|
||||||
|
# so they are processed without interruptions and extra overheads
|
||||||
|
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
|
||||||
|
MAX_SHORT_INFERENCE_TOKENS = 128
|
||||||
|
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
|
||||||
|
|
||||||
|
|
||||||
|
async def run_rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :param active_adapter: name of the LoRA adapter to activate on each backend ("" = base weights)
    :param prioritizer: assigns scheduling priority to each submitted task
    :param points: client's budget, split evenly across the requested backends
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        # No deep prompts: substitute one dummy per backend so zip() below still works
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            # Deep ptune prompt: added to the first prompt-length positions of the hidden states
            hidden_states[:, : prompt.shape[1]] += prompt

        # NOTE(review): the assert inspects inference_pool while the task is submitted to
        # forward_pool below — confirm this is intentional (both pools may share the class).
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
async def run_rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """
    Run a backward pass through a chain of served blocks.

    :param flat_tensors: (inputs, grad_outputs, prompts) — first-layer inputs, gradients w.r.t. the
        last layer's outputs, and optional deep prompts (dummy tensor if unused)
    :param requested_backends: transformer blocks in forward order
    :param active_adapter: name of the LoRA adapter to activate ("" = base weights)
    :param prioritizer: assigns scheduling priority to each submitted task
    :param points: client's budget, split evenly across backends
    :returns: [grad_inputs] or [grad_inputs, grad_prompts] if deep prompts were used
    """
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    # The last block's input still needs its prompt applied (it was skipped by the loop above)
    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            # Gradient w.r.t. this block's prompt is the grad of its first prompt-length positions
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    # Backends were traversed in reverse, so flip the collected prompt grads back to forward order
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
|
||||||
|
|
||||||
|
|
||||||
|
async def iterate_rpc_inference(
    requested_uids: Sequence[ExpertUID],
    requested_backends: Sequence[TransformerBackend],
    active_adapter: Optional[str],
    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
    cache_handles: Sequence[Sequence[Handle]],
    *,
    max_length: int,
    prioritizer: TaskPrioritizerBase,
    points: int,
    quant_type: QuantType,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
    """
    Consume inference steps from `input_iterator` and yield, for each step, the serialized
    last-layer hidden states and a flag telling whether the result may be pushed downstream.

    :param requested_uids: UIDs of the served blocks, parallel to `requested_backends`
    :param cache_handles: per-backend attention-cache handles, parallel to `requested_backends`
    :param max_length: pre-allocated cache length; total processed tokens must not exceed it
    :param quant_type: server quantization mode; NF4 lowers the merged-pool token threshold
    """
    assert len(cache_handles) == len(requested_backends)

    prefix_length = 0
    point_per_piece = points / max_length if max_length > 0 else 0.0

    async for request, step_metadata in input_iterator:
        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
        batch_size, length_increment, _ = hidden_states.shape

        # Cast inputs to backend dtype
        hidden_states = hidden_states.to(requested_backends[0].dtype)
        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

        # parse deep prompts (optional argument)
        has_prompts = prompts is not None and not is_dummy(prompts)
        if not has_prompts:
            prompts = [None] * len(requested_backends)
        else:
            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]

        if not (len(requested_backends) == len(prompts)):
            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

        if prefix_length + length_increment > max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                f" exceeds pre-allocated maximum {max_length}"
            )

        # Short requests may run through a single merged pool to avoid per-block scheduling overhead
        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
        can_merge_pools = batch_size * length_increment <= merge_max_tokens
        priority = prioritizer.prioritize(
            hidden_states,
            hypo_ids,
            points=point_per_piece,
            requested_uids=requested_uids,
            type="short_inference" if can_merge_pools else "inference",
        )

        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
        if hidden_states.numel() > 0:
            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
            if can_merge_pools:
                # One task covering all blocks: build metadata for every backend at once
                inference_infos = tuple(
                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                    for uid, handles in zip(requested_uids, cache_handles)
                )
                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                )
            else:
                # One task per block, feeding each block's output into the next
                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                    (hidden_states,) = await backend.inference_pool.submit_task(
                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
                    )

        # serialize and send last layer outputs
        output_tensors = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]
        # Pushing downstream is only allowed when no client-side deep prompts are involved
        can_push = not has_prompts
        yield output_tensors, can_push

        # prepare for next step
        prefix_length += length_increment
|
@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
||||||
|
If necessary, one can rewrite this to implement a different behavior, such as:
|
||||||
|
- loading files from a local data source (e.g. S3)
|
||||||
|
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
||||||
|
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
||||||
|
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from huggingface_hub import get_hf_file_metadata, hf_hub_url
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.utils import get_file_from_repo
|
||||||
|
|
||||||
|
from petals.constants import DTYPE_MAP
|
||||||
|
from petals.server.block_utils import resolve_block_dtype
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||||
|
from petals.utils.hf_auth import always_needs_auth
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pretrained_block(
    model_name: str,
    block_index: int,
    *,
    config: Optional[PretrainedConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    """
    Build one transformer block on the meta device, then fill it with pretrained weights
    fetched from the Hugging Face Hub (or the local cache).

    :param model_name: repo id of the pretrained model
    :param block_index: index of the block to load; selects the `{block_prefix}.{block_index}.` weights
    :param config: model config; downloaded via AutoDistributedConfig if not given
    :param torch_dtype: target dtype, must be one of DTYPE_MAP values; "auto" is resolved per config
    :param max_disk_space: cache size budget in bytes, enforced before downloading
    :returns: the block as an nn.Module with weights materialized on CPU
    """
    if config is None:
        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    # Instantiate the block without allocating real weight storage (meta tensors)
    with init_empty_weights():
        block = config.block_class(config)

    block_prefix = f"{config.block_prefix}.{block_index}."
    state_dict = _load_state_dict_from_repo(
        model_name,
        block_prefix,
        revision=revision,
        token=token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        # Keep integer/bool tensors (e.g. quantization metadata) in their original dtype
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {model_name} block {block_index}, {report}")
    return block
|
||||||
|
|
||||||
|
|
||||||
|
StateDict = Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_from_repo(
    model_name: str,
    block_prefix: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
) -> StateDict:
    """Gather every tensor whose name starts with `block_prefix` from a HF repo (sharded or not)."""
    if always_needs_auth(model_name) and token is None:
        token = True

    index_file = get_file_from_repo(
        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
    )
    if index_file is None:  # Non-sharded model
        filenames = {"pytorch_model.bin"}
    else:  # Sharded model: pick only shards holding this block's parameters
        with open(index_file) as f:
            index = json.load(f)
        filenames = {
            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
        }
        if not filenames:
            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
    logger.debug(f"Loading {block_prefix}* from {filenames}")

    state_dict = {}
    for shard_filename in filenames:
        shard = _load_state_dict_from_file(
            model_name,
            shard_filename,
            revision=revision,
            token=token,
            cache_dir=cache_dir,
            max_disk_space=max_disk_space,
        )
        # Keep only this block's parameters, stripping the block prefix from their names;
        # irrelevant tensors are dropped immediately to free memory
        state_dict.update(
            (param_name[len(block_prefix) :], param)
            for param_name, param in shard.items()
            if param_name.startswith(block_prefix)
        )
    return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_from_file(
    model_name: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
) -> StateDict:
    """
    Load one checkpoint file, preferring the local cache and falling back to downloading.

    Retries the download indefinitely every `delay` seconds on failure.

    :param model_name: repo id of the pretrained model
    :param filename: checkpoint file name within the repo
    :param max_disk_space: cache size budget in bytes, enforced before downloading
    :returns: the file's contents as a state dict (torch.load on CPU)
    """
    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            path = get_file_from_repo(
                model_name,
                filename,
                revision=revision,
                use_auth_token=token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if path is not None:
                return torch.load(path, map_location="cpu")
    except Exception:
        # Best-effort: a corrupted cache entry should trigger a re-download, not a crash.
        # Fix: interpolate the actual filename instead of the literal "(unknown)".
        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    while True:
        try:
            with allow_cache_writes(cache_dir):
                url = hf_hub_url(model_name, filename, revision=revision)
                file_size = get_hf_file_metadata(url, token=token).size
                if file_size is not None:
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")

                path = get_file_from_repo(
                    model_name,
                    filename,
                    revision=revision,
                    use_auth_token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                if path is None:
                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
                return torch.load(path, map_location="cpu")
        except Exception:
            # Unbounded retry by design: servers should keep trying until the Hub recovers
            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)
|
@ -0,0 +1,164 @@
|
|||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import Future
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from hivemind.dht import DHT, DHTNode
|
||||||
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||||
|
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
|
||||||
|
from hivemind.proto import dht_pb2
|
||||||
|
from hivemind.utils import get_logger
|
||||||
|
|
||||||
|
from petals.constants import REACHABILITY_API_URL
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
    """
    verify that your peer is reachable from a (centralized) validator, whether directly or through a relay

    Polls the health.petals.dev API every `retry_delay` seconds for up to `wait_time` seconds.
    Returns silently if reachable or if the health service itself is unavailable;
    raises RuntimeError with setup instructions if the peer never becomes reachable.
    """
    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
        try:
            r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
            r.raise_for_status()
            response = r.json()

            if response["success"]:
                logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon")
                return

            if attempt_no == 0:
                # Usually, libp2p manages to set up relays before we finish loading blocks.
                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
            time.sleep(retry_delay)
        except Exception as e:
            # Deliberate best-effort: a broken health service must not block server startup
            logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}")
            return

    # `response` here is the last (unsuccessful) API reply from the loop above
    raise RuntimeError(
        f"Server has not become reachable from the Internet:\n\n"
        f"{response['message']}\n\n"
        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
        f"  1. Choose a specific port for the Petals server, for example, 31337.\n"
        f"  2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
        f"  3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
        f"     python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
        f"  4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
    """
    test if your peer is accessible by others in the swarm with the specified network options in **kwargs

    Asks up to `max_peers` remote peers to probe this node; returns True if at least a
    `threshold` fraction of responders could reach it, False otherwise, and None if no
    peer responded to the probe at all.
    """

    async def _check_direct_reachability():
        target_dht = await DHTNode.create(client_mode=True, **kwargs)
        try:
            protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
            async with protocol.serve(target_dht.protocol.p2p):
                # Renamed local counter: `requests` shadowed the `requests` HTTP library
                # imported at module level.
                successes = num_requests = 0
                for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
                    probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
                    if probe_available is None:
                        continue  # remote peer failed to check probe
                    successes += probe_available
                    num_requests += 1
                    if num_requests >= max_peers:
                        break

                logger.debug(f"Direct reachability: {successes}/{num_requests}")
                return (successes / num_requests) >= threshold if num_requests > 0 else None
        finally:
            await target_dht.shutdown()

    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
|
||||||
|
|
||||||
|
|
||||||
|
STRIPPED_PROBE_ARGS = dict(
|
||||||
|
dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReachabilityProtocol(ServicerBase):
    """Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""

    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
        # P2P instance used to send outgoing probe requests (may be created later by attach_to_dht)
        self.probe = probe
        # Base timeout for a single rpc_check call; doubled for indirect (two-hop) checks
        self.wait_timeout = wait_timeout
        # Set by the background service loop in attach_to_dht; used by shutdown()
        self._event_loop = self._stop = None

    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
        try:
            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
            # Indirect checks need an extra hop, so allow twice the timeout
            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
            return response.available
        except Exception as e:
            # Best-effort probe: any failure is reported as "no response"
            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
            return None

    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        """Help another peer to check its reachability"""
        response = dht_pb2.PingResponse(available=True)
        check_peer = PeerID(request.peer.node_id)
        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
        logger.info(
            f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
            f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
        )
        return response

    @asynccontextmanager
    async def serve(self, p2p: P2P):
        """Register this protocol's RPC handlers on `p2p` for the duration of the context."""
        try:
            await self.add_p2p_handlers(p2p)
            yield self
        finally:
            await self.remove_p2p_handlers(p2p)

    @classmethod
    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
        """Start serving the reachability protocol on `dht`'s p2p in a daemon thread; return the protocol."""
        protocol = cls(**kwargs)
        ready = Future()

        async def _serve_with_probe():
            try:
                common_p2p = await dht.replicate_p2p()
                protocol._event_loop = asyncio.get_event_loop()
                protocol._stop = asyncio.Event()

                # Build a separate client-mode P2P ("probe") that connects to the same peers
                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
                for info in await common_p2p.list_peers():
                    initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)

                ready.set_result(True)
                logger.info("Reachability service started")

                async with protocol.serve(common_p2p):
                    await protocol._stop.wait()
            except Exception as e:
                logger.debug("Reachability service failed:", exc_info=True)

                if not ready.done():
                    ready.set_exception(e)
            finally:
                if protocol is not None and protocol.probe is not None:
                    await protocol.probe.shutdown()
                # NOTE(review): original indentation of this log line is ambiguous in the
                # scraped source; placed at finally-level (always logged) — confirm.
                logger.debug("Reachability service shut down")

        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
        if await_ready:
            ready.result()  # Propagates startup exceptions, if any
        return protocol

    def shutdown(self):
        """Signal the background service loop (if running) to stop, from any thread."""
        if self._event_loop is not None and self._stop is not None:
            self._event_loop.call_soon_threadsafe(self._stop.set)
|
@ -0,0 +1,6 @@
|
|||||||
|
from petals.utils.auto_config import (
|
||||||
|
AutoDistributedConfig,
|
||||||
|
AutoDistributedModel,
|
||||||
|
AutoDistributedModelForCausalLM,
|
||||||
|
AutoDistributedModelForSequenceClassification,
|
||||||
|
)
|
@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
|
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
|
from petals.utils.hf_auth import always_needs_auth
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _ModelClasses:
    """Bundle of Petals classes registered for one `model_type` (see register_model_classes)."""

    config: Type[PretrainedConfig]  # required: the distributed config class
    model: Optional[Type[PreTrainedModel]] = None  # bare transformer, if supported
    model_for_causal_lm: Optional[Type[PreTrainedModel]] = None  # causal-LM variant, if supported
    model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None  # classification variant, if supported
|
||||||
|
|
||||||
|
|
||||||
|
_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
|
||||||
|
|
||||||
|
|
||||||
|
def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
    """Register the Petals classes for one model type in the global _CLASS_MAPPING."""
    assert issubclass(config, PretrainedConfig)
    assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"

    entry = _ModelClasses(config=config, **kwargs)
    _CLASS_MAPPING[config.model_type] = entry
|
||||||
|
|
||||||
|
|
||||||
|
class _AutoDistributedBase:
    """AutoConfig-style dispatcher: resolves `model_type` to a registered Petals class."""

    _mapping_field = None  # Should be defined in child classes

    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
        """
        Resolve the proper distributed class for `model_name_or_path` and delegate loading to it.

        :raises ValueError: if the model type is unknown or has no class for this factory
        """
        # Some repos always require authentication; default to the stored HF token then
        if (
            always_needs_auth(model_name_or_path)
            and kwargs.get("token") is None
            and kwargs.get("use_auth_token") is None
        ):
            kwargs["use_auth_token"] = True

        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
        if config.model_type not in _CLASS_MAPPING:
            raise ValueError(f"Petals does not support model type {config.model_type}")

        proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
        if proper_cls is None:
            raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")

        return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoDistributedConfig(_AutoDistributedBase):
    """Auto-factory resolving to the registered distributed config class."""

    _mapping_field = "config"


class AutoDistributedModel(_AutoDistributedBase):
    """Auto-factory resolving to the registered bare distributed model class."""

    _mapping_field = "model"


class AutoDistributedModelForCausalLM(_AutoDistributedBase):
    """Auto-factory resolving to the registered causal-LM model class."""

    _mapping_field = "model_for_causal_lm"


class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
    """Auto-factory resolving to the registered sequence-classification model class."""

    _mapping_field = "model_for_sequence_classification"
|
@ -1,39 +0,0 @@
|
|||||||
import bitsandbytes as bnb
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
|
|
||||||
|
|
||||||
|
|
||||||
def replace_8bit_linear(model, threshold=6.0):
    """
    Recursively convert every `torch.nn.Linear` inside *model* to a `CustomLinear8bitLt`
    (mixed int8 precision as described in the GPT3.int8() paper), copying the original
    weights and bias. Make sure a `bitsandbytes` build matching your CUDA version is
    installed before calling this.

    The `lm_head` and `score` modules are deliberately left as plain `torch.nn.Linear`.

    Parameters:
        model (`torch.nn.Module`):
            Input model or submodule; the function recurses into children.
        threshold (`float`, *optional*):
            `int8_threshold` for outlier detection, `6.0` by default as in the paper.
    """
    for name, child in model.named_children():
        if list(child.children()):
            # Recurse into containers before (possibly) replacing leaf linears
            replace_8bit_linear(child, threshold)

        if isinstance(child, torch.nn.Linear) and name not in ("lm_head", "score"):
            quantized = CustomLinear8bitLt(
                child.in_features,
                child.out_features,
                child.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            )
            quantized.weight = bnb.nn.Int8Params(
                child.weight.data, requires_grad=False, has_fp16_weights=False
            ).to(child.weight.dtype)
            quantized.bias = child.bias
            model._modules[name] = quantized
    return model
|
|
@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import tensor_parallel as tp
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||||
|
from tensor_parallel.slicing_configs import get_bloom_config
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
use_hivemind_log_handler("in_root_logger")
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantType(Enum):
    """Weight quantization modes supported when converting a transformer block."""

    NONE = 0  # no quantization
    INT8 = 1  # 8-bit as in the LLM.int8() paper
    NF4 = 2  # 4-bit as in the QLoRA paper
|
||||||
|
|
||||||
|
|
||||||
|
def convert_block(
    block: nn.Module,
    block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
    adapters: Optional[Sequence[str]] = None,
    **kwargs,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server: tensor parallelism and/or quantization.

    :note: some optimizations modify the input block in-place!
    :param block: a single transformer block, either pre-trained or newly initialized
    :param block_index: index of this block in the full model (used to slice adapter weights)
    :param config: HF transformers config for the full model
    :param tensor_parallel_devices: devices to split the block between; a single device still
        gets wrapped with TensorParallel for uniformity
    :param output_device: device that gathers the block outputs
    :param quant_type: quantization type to apply after wrapping
    :param freeze: if True (default), make all module parameters non-trainable
    :param adapters: optional LoRA adapter repo names to load into the block
    :return: a module that acts like the original block, but runs with all specified optimizations
    """
    if freeze:
        block.requires_grad_(False)

    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)

    if quant_type != QuantType.NONE:
        block = quantize_module(block, quant_type=quant_type)

    # Move each shard to its assigned device only after quantization (layers must be on CPU for it)
    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

    if adapters:
        # Imported lazily to avoid a hard dependency on peft when no adapters are used
        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

        create_lora_adapter(block, quant_type=quant_type)
        for adapter_name in adapters:
            adapter_config, adapter_state_dict = load_peft(adapter_name, block_idx=block_index, **kwargs)
            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)

    return block
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
    """
    Recursively replace `torch.nn.Linear` layers in *model* with quantized bitsandbytes
    layers (`Linear8bitLt` for INT8, `LinearNF4` for NF4), copying weights and bias.
    The `lm_head` and `score` modules are kept as plain linears. Layers must be on CPU.
    """
    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
    import bitsandbytes as bnb

    for name, module in model.named_children():
        if list(module.children()):
            quantize_module(module, quant_type=quant_type)

        if isinstance(module, torch.nn.Linear) and name not in ("lm_head", "score"):
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
            if quant_type == QuantType.INT8:
                replacement = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=6.0,  # Default from the LLM.int8() paper
                )
                replacement.weight = bnb.nn.Int8Params(
                    module.weight.data, requires_grad=False, has_fp16_weights=False
                ).to(module.weight.dtype)
            elif quant_type == QuantType.NF4:
                compress_statistics = True
                replacement = bnb.nn.LinearNF4(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    compress_statistics=compress_statistics,
                )
                replacement.weight = bnb.nn.Params4bit(
                    module.weight.data,
                    requires_grad=False,
                    quant_type="nf4",
                    blocksize=64,
                    compress_statistics=compress_statistics,
                ).to(module.weight.dtype)
            else:
                raise ValueError(f"Unsupported quant_type='{quant_type}'")
            replacement.bias = module.bias
            model._modules[name] = replacement
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def make_tensor_parallel(
    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
) -> nn.Module:
    """Wrap *block* with tensor_parallel.TensorParallel across *devices*, gathering on *output_device*."""
    if model_config.model_type == "bloom":
        tp_config = get_bloom_config(model_config, devices)
        # Word embeddings live outside the block; drop their slicing rule
        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
    else:
        if len(devices) > 1:
            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
        tp_config = None
    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
    # Sanity check: the shards together must cover every attention head exactly once
    total_heads = sum(
        submodule.num_heads
        for tp_shard in tp_block.module_shards
        for submodule in tp_shard.modules()
        if isinstance(submodule, model_config.attn_class)
    )
    assert total_heads == model_config.num_attention_heads
    return tp_block
|
||||||
|
|
||||||
|
|
||||||
|
def check_device_balance(devices: Sequence[torch.device]):
    """Warn if the given GPUs are unevenly matched (capability or memory) for tensor parallelism."""
    if any(device.type != "cuda" for device in devices):
        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
        return
    unique_device_capabilities = {torch.cuda.get_device_capability(device) for device in devices}
    if len(unique_device_capabilities) > 1:
        logger.warning(
            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
        )

    # Effective capacity is bounded by the smallest GPU, since shards are split evenly
    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
    used_memory = min(memory_per_device) * len(memory_per_device)
    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
    if wasted_memory_rate > 0.05:
        logger.warning(
            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
            f"Consider running high-memory GPUs in a separate server."
        )
|
@ -0,0 +1,7 @@
|
|||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
    """Return True if *model_name* refers to a Hub repo that always requires authentication
    (currently, the gated meta-llama/Llama-2-* checkpoints). Local paths never need auth."""
    if model_name is None or os.path.isdir(model_name):
        return False
    return model_name.startswith("meta-llama/Llama-2-")
|
@ -1,334 +0,0 @@
|
|||||||
"""
|
|
||||||
A patch to bitsandbytes 0.34.0 that introduces an option to run backward pass in default (fast) matrix layout.
|
|
||||||
Authors: modification by @borzunov, original code by @timdettmers. Please disregard commit authors in this file.
|
|
||||||
|
|
||||||
Core idea: layouts apply the same permutation to every tile in the matrix. We can treat this as (batched) gather ops.
|
|
||||||
Reshape input tensor so that ij-th gather operation op will apply to ij-th elements in each tile.
|
|
||||||
Prototype: https://colab.research.google.com/drive/1EJ0MKifajXSSVq7O2_QGwtb0l6gRAGrh?usp=sharing
|
|
||||||
Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#L2130-L2136
|
|
||||||
Exact match tests: see $REPO/tests/test_linear8bitlt.py
|
|
||||||
"""
|
|
||||||
import dataclasses
|
|
||||||
import logging
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import bitsandbytes.functional as F
|
|
||||||
import torch
|
|
||||||
from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatMul8bitLt, MatmulLtState, prod
|
|
||||||
from bitsandbytes.nn import Linear8bitLt
|
|
||||||
|
|
||||||
|
|
||||||
def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
    """
    Compute the permutation of flat indices that inverts the given (tiled) matrix transformation.

    :param transform_tile: applies the forward transform to an int8 CPU tensor of shape [dim1, dim2]
    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
    :note: we assume that transform_tile applies to a cpu-based int8 tensor of shape tile_size
    :returns: a [dim1, dim2] index tensor suitable for undo_layout
    """
    d1, d2 = tile_size
    assert 0 < d1 * d2 < 2**64
    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
    # Encode each flat position as a tuple of <= 8 unique bytes, push each byte through the
    # transform, and reassemble to learn where every index ended up
    permuted_tile_indices = torch.zeros_like(tile_indices)
    for byte in range(8):
        ith_dim_indices = torch.div(tile_indices, 256**byte, rounding_mode="trunc") % 256
        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
        permuted_tile_i = transform_tile(sample_tile_i)
        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
        permuted_tile_indices += ith_permuted_indices * (256**byte)
        if d1 * d2 < 256**byte:
            break  # all indices already fit in this many bytes; stop early
    return permuted_tile_indices
|
|
||||||
|
|
||||||
|
|
||||||
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as the turing or ampere weight layout.

    :param permuted_tensor: torch tensor stored in the permuted layout
    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
    :return: contiguous row-major tensor with the permutation undone
    """
    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
    # Scatter by assignment; per the original author, .index_copy was slower on cuda
    outputs = torch.empty_like(tensor)
    outputs[tile_indices.flatten()] = tensor
    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
    return outputs.reshape(rows, cols).contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
# the rest of this file is just a patch to bitsandbytes that modifies Linear8bitLt and dependencies
|
|
||||||
|
|
||||||
|
|
||||||
class CustomLinear8bitLt(Linear8bitLt):
    """Linear8bitLt variant backed by CustomMatmulLtState, enabling backward in the default layout."""

    def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
        assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
        super().__init__(*args, **kwargs)
        # Swap the parent's state object for the custom one, carrying over its settings
        old_state, self.state = self.state, CustomMatmulLtState()
        self.state.threshold = old_state.threshold
        self.state.has_fp16_weights = old_state.has_fp16_weights
        self.state.memory_efficient_backward = old_state.memory_efficient_backward
        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
            self.state.use_pool = True

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state)
        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # the first inference pass converted 8-bit row major to turing/ampere format;
                # the row-major copy is now redundant, so free it
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(init=True)
class CustomMatmulLtState(MatmulLtState):
    """MatmulLtState extended with cached inverse-layout indices and an igemmlt kill-switch."""

    tile_indices: Optional[torch.Tensor] = None  # cached inverse permutation for undo_layout
    force_no_igemmlt: bool = False  # set True to force the pre-Turing fallback path

    def get_tile_size(self):
        """Return the (rows, cols) tile shape for the current formatB layout."""
        assert self.formatB in (
            "col_turing",
            "col_ampere",
        ), f"please find this assert and manually enter tile size for {self.formatB}"
        return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
|
||||||
|
|
||||||
|
|
||||||
def custom_matmul8bitlt(
    A: torch.Tensor,
    B: torch.Tensor,
    out: torch.Tensor = None,
    state: CustomMatmulLtState = None,
    threshold=0.0,
    bias=None,
):
    """Drop-in replacement for bitsandbytes matmul that dispatches to CustomMatMul8bitLt."""
    if state is None:
        state = MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return CustomMatMul8bitLt.apply(A, B, out, bias, state)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomMatMul8bitLt(MatMul8bitLt):
    """Patched MatMul8bitLt autograd function.

    Differences from upstream bitsandbytes (per the original patch authors):
    forward adds a fallback path for pre-Turing GPUs (no igemmlt); backward adds one
    extra clause (see "elif state.CxB is not None") that reconstructs row-major weights
    from the turing/ampere layout via undo_layout, so backward works without keeping
    a separate transposed copy of the weights.
    """

    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState):
        """Quantized forward pass: quantize A (and B if fp16), matmul, then add the
        mixed-precision outlier correction. Saves quantization state on ctx for backward."""
        # igemmlt requires compute capability >= (7, 5); force_no_igemmlt is a test/debug override
        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        if A.dtype != torch.float16:
            logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None and using_igemmlt:
                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                if using_igemmlt:
                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
                else:
                    state.CB = CB
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            if state.CxB is not None:
                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            else:
                # pre-Turing fallback: weights are still row-major in state.CB
                outliers = state.CB[:, state.idx.long()].clone()

            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0] if state.SB else B.shape

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        if using_igemmlt:
            C32A, SA = F.transform(CA, "col32")
            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
            if bias is None or bias.dtype == torch.float16:
                # we apply the fused bias here
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                output = output.to(A.dtype)
            else:  # apply bias separately
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
                output = output.to(A.dtype).add_(bias)

        else:
            # Pre-Turing fallback: dequantized fp matmul with outlier columns zeroed
            A_wo_outliers = A.clone()
            if state.idx is not None:
                A_wo_outliers[:, state.idx.long()] = 0
            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
            if bias is not None:
                output = output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass; unlike upstream, can also recover weights from the CxB layout."""
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                # add the outlier-column contribution that was excluded from the int8 path
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

            elif state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CxB is not None:
                # Extra clause vs upstream: recover row-major weights from the tiled layout.
                # Lazily compute and cache the inverse permutation the first time it is needed.
                if state.tile_indices is None:
                    order, tile_size = state.formatB, state.get_tile_size()
                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
                    with torch.no_grad():
                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)

                CB = (
                    undo_layout(state.CxB, state.tile_indices)
                    .to(ctx.dtype_A)
                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                )
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None
|
|
@ -0,0 +1,288 @@
|
|||||||
|
import contextlib
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
||||||
|
from peft.tuners import lora
|
||||||
|
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from transformers.utils import get_file_from_repo
|
||||||
|
|
||||||
|
from petals.server.block_utils import resolve_block_dtype
|
||||||
|
from petals.utils.convert_block import QuantType
|
||||||
|
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_peft_repository(repo_id: str) -> bool:
    """Return True if the Hub repo *repo_id* contains safetensors adapter weights."""
    matched_files = HfFileSystem().glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return bool(matched_files)
|
||||||
|
|
||||||
|
|
||||||
|
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    """
    Load from the safetensors file *filepath* only the tensors that belong to transformer
    block *block_idx*, matching keys of the form "...<common_layer_name>.<block_idx>....".

    :param block_idx: index of the transformer block whose adapter tensors are wanted
    :param filepath: path to a safetensors checkpoint
    :param framework: safetensors framework tag, "pt" by default
    :param device: device index to load tensors onto (safetensors convention)
    :return: dict mapping tensor keys to loaded tensors (possibly empty, with a warning)
    """
    tensors = {}
    found_for_block = False
    # BUGFIX: the pattern pieces were non-raw strings, so "\." was an invalid escape
    # sequence (SyntaxWarning on Python 3.12+); raw strings preserve the same regex
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for key in f.keys():
            if re.match(common_layer_pattern_re, key):
                found_for_block = True
                tensors[key] = f.get_tensor(key)
    if not found_for_block:
        logger.warning(f"There is no peft weights for block {block_idx}")
    return tensors
|
||||||
|
|
||||||
|
|
||||||
|
def get_adapter_from_repo(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    token: Optional[Union[str, bool]] = None,
    **kwargs,
):
    """
    Fetch a PEFT adapter config and its safetensors weights from the Hub cache/repo.

    :param repo_id: Hub repo containing the adapter
    :param block_idx: if given, load only the tensors for this transformer block
    :param device: device index for safetensors loading (used with block_idx)
    :param token: Hub auth token, forwarded as use_auth_token
    :return: (PeftConfig, state_dict) tuple
    :raises RuntimeError: when the config or weights file is missing from the repo
    """
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        weights = load_file(weight_path)  # load the whole adapter at once
    else:
        weights = load_specific_module(block_idx, weight_path, device=device)
    return config, weights
|
||||||
|
|
||||||
|
|
||||||
|
def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    """
    Load a PEFT adapter (config + safetensors weights), first from the local cache and,
    failing that, by downloading from the Hub with disk-space management and retries.

    :param repo_id: Hub repo containing the adapter (must contain safetensors weights)
    :param block_idx: if given, load only the tensors for this transformer block
    :param device: device index for safetensors loading
    :param revision: optional Hub revision
    :param token: Hub auth token
    :param cache_dir: local cache directory (also guarded by allow_cache_* locks)
    :param max_disk_space: optional cap passed to free_disk_space_for
    :param delay: seconds to wait between download retries
    :raises ValueError: if the repo has no safetensors weights
    """
    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here

    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=token).size

                # BUGFIX: the original added the two sizes before checking for None, so an
                # unknown size raised TypeError instead of reaching the fallback warning
                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterContextMixin:
    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
    _context_active_adapter = ADAPTER_NOT_SET  # class-level, shared by all mixed-in layers

    @staticmethod
    @contextlib.contextmanager
    def using_adapter(active_adapter: Optional[str]):
        """Temporarily select *active_adapter* for every layer that uses this mixin."""
        prev = AdapterContextMixin._context_active_adapter
        AdapterContextMixin._context_active_adapter = active_adapter
        try:
            yield
        finally:
            AdapterContextMixin._context_active_adapter = prev

    @property
    def active_adapter(self):
        if self._context_active_adapter == self.ADAPTER_NOT_SET:
            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
        return self._context_active_adapter

    @active_adapter.setter
    def active_adapter(self, value: Optional[str]):
        # peft assigns this attribute during wrapping; only the sentinel value is accepted,
        # everything else must go through using_adapter
        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"
|
||||||
|
|
||||||
|
|
||||||
|
using_adapter = AdapterContextMixin.using_adapter
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLinear(lora.Linear, AdapterContextMixin):
    """LoRA linear layer that uses adapter selected via using_adapter"""
    # No extra body needed: math comes from peft's lora.Linear; adapter selection comes from the mixin
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
    # 8-bit (LLM.int8) variant; behavior inherited from peft, adapter routing from the mixin
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
    # 4-bit (NF4) variant; behavior inherited from peft, adapter routing from the mixin
|
||||||
|
|
||||||
|
|
||||||
|
def create_lora_adapter(block, quant_type: QuantType):
    """Replace every linear layer inside `block` with a LoRA-capable wrapper matching `quant_type`.

    Wrappers are created with no adapter selected (ADAPTER_NOT_SET); adapters are attached later
    via add_adapter_to_block. The original weight/bias tensors are reused in place, and all
    parameters of the wrapper are frozen (requires_grad=False).
    """
    for _, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            # Only plain or bitsandbytes-quantized linear layers get wrapped
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,  # LLM.int8() outlier threshold used elsewhere in this codebase
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear8bitLt(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear4bit(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
                # 4-bit layers carry a compute dtype that must be copied onto the wrapper
                lora_wrapped_child.compute_dtype = child.compute_dtype
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = LoraLinear(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                # Reuse the original (possibly quantized) parameters instead of re-allocating
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)
|
||||||
|
|
||||||
|
|
||||||
|
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    """Load LoRA weights from `peft_state_dict` into the wrapped linear layers of `block`.

    Expects the block to have been prepared with create_lora_adapter. Raises ValueError if the
    state dict provides only one of the two LoRA matrices (A without B, or vice versa) for a layer,
    and NotImplementedError for adapters that ship LoRA biases.
    """
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    if peft_config["lora_dropout"] > 0:
        # Dropout is a training-time feature; this server runs the adapter with dropout off
        logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")

    for _, module in block.named_modules():
        for child_name, child in module.named_children():
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            # target_modules may be a collection of layer names or a single regex string
            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    if child_name not in peft_key:
                        continue

                    if adapter_name not in child.lora_A:
                        # Lazily register the adapter on the first matching key
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            lora_dropout=peft_config["lora_dropout"],
                            init_lora_weights=peft_config["init_lora_weights"],
                        )
                        child.train(False)
                        for p in child.parameters():
                            p.requires_grad = False

                    if peft_key.endswith(".lora_A.weight"):
                        # Copy in place ([...]) to keep the existing tensor object (device/dtype) intact
                        child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
                        is_lora_a_loaded = True
                    elif peft_key.endswith(".lora_A.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                    elif peft_key.endswith(".lora_B.weight"):
                        child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
                        is_lora_b_loaded = True
                    elif peft_key.endswith(".lora_B.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

                if is_lora_a_loaded and is_lora_b_loaded:
                    logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
                elif is_lora_a_loaded or is_lora_b_loaded:
                    # Exactly one of A/B present -> corrupted or mismatched adapter
                    raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
    logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_adapter_memory_per_block(
    block_config: transformers.PretrainedConfig,
    torch_dtype: Optional[torch.dtype],
    adapters: Sequence[str],
    **load_peft_kwargs,
) -> int:
    """Get the number of extra bytes used to store a set of adapters per given block.

    Builds the block on the meta device (init_empty_weights), attaches each adapter, and
    counts the parameters added on top of the base block. Extra kwargs are forwarded to load_peft.
    """
    with init_empty_weights(include_buffers=True):
        block = block_config.block_class(block_config)
        base_block_parameters = sum(p.numel() for p in block.parameters())
        create_lora_adapter(block, quant_type=QuantType.NONE)

        for adapter in adapters:
            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
            add_adapter_to_block(
                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
            )
        # Everything beyond the base parameter count was added by the adapters
        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters

    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
    # Fix: `.bits / 8` yields a float; cast so the function matches its declared `-> int`
    return int(adapter_parameters * bytes_per_parameter)
|
@ -0,0 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
from hivemind.proto import dht_pb2
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def ping(
    peer_id: hivemind.PeerID,
    _dht: hivemind.DHT,
    node: hivemind.dht.DHTNode,
    *,
    wait_timeout: float = 5,
) -> float:
    """Measure the round-trip time (seconds) of an RPC ping to `peer_id`; returns math.inf on failure."""
    # Fix: bind start_time before the try block — the except branch below reads it, and the
    # original bound it after PingRequest construction, risking a NameError on early failure
    start_time = time.perf_counter()
    try:
        ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
        await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
        return time.perf_counter() - start_time
    except Exception as e:
        if str(e) == "protocol not supported":  # Happens on servers with client-mode DHT (e.g., reachable via relays)
            return time.perf_counter() - start_time

        logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
        return math.inf
|
||||||
|
|
||||||
|
|
||||||
|
async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
    """Ping all peers concurrently and map each peer ID to its measured RTT."""
    pending = [ping(pid, *args, **kwargs) for pid in peer_ids]
    rtts = await asyncio.gather(*pending)
    return {pid: rtt for pid, rtt in zip(peer_ids, rtts)}
|
||||||
|
|
||||||
|
|
||||||
|
class PingAggregator:
    """Maintains exponentially smoothed RTT estimates to a set of peers, with expiration."""

    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
        self.dht = dht
        self.ema_alpha = ema_alpha  # smoothing factor: weight given to the newest measurement
        self.expiration = expiration  # seconds before a stored RTT estimate expires
        self.ping_emas = hivemind.TimedStorage()
        self.lock = threading.Lock()

    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
        """Ping the given peers and fold the fresh RTTs into the stored EMAs."""
        current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
        logger.debug(f"Current RTTs: {current_rtts}")

        with self.lock:
            expiration = hivemind.get_dht_time() + self.expiration
            for peer_id, rtt in current_rtts.items():
                prev_rtt = self.ping_emas.get(peer_id)
                # A previous inf (failed ping) is discarded rather than smoothed into the new value
                if prev_rtt is not None and prev_rtt.value != math.inf:
                    rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
                self.ping_emas.store(peer_id, rtt, expiration)

    def to_dict(self) -> Dict[hivemind.PeerID, float]:
        """Return a snapshot {peer_id: smoothed_rtt} of all non-expired estimates."""
        with self.lock, self.ping_emas.freeze():
            smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
            logger.debug(f"Smoothed RTTs: {smoothed_rtts}")  # fixed log typo: "Smothed" -> "Smoothed"
            return smoothed_rtts
|
@ -0,0 +1,12 @@
|
|||||||
|
import random
from typing import Collection, List, TypeVar

T = TypeVar("T")


def sample_up_to(population: Collection[T], k: int) -> List[T]:
    """Return up to ``k`` items drawn uniformly without replacement from ``population``.

    If the population has at most ``k`` elements, all of them are returned as a list
    (in the original iteration order); otherwise a random sample of exactly ``k`` is drawn.
    Fix: the return annotation was ``-> T``, but the function always returns a list of T.
    """
    if not isinstance(population, list):
        population = list(population)
    if len(population) > k:
        population = random.sample(population, k)
    return population
|
@ -0,0 +1,44 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from hivemind.utils.logging import TextStyle, get_logger
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
|
import petals
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_version() -> None:
    """Log the running Petals version and, best-effort, suggest upgrading if PyPI has a newer stable release."""
    logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
    try:
        response = requests.get("https://pypi.python.org/pypi/petals/json")
        response.raise_for_status()
        payload = response.json()

        # Releases are keyed by version string; pre-releases are ignored when picking the latest
        release_versions = [parse(ver) for ver in payload.get("releases")]
        latest = max(ver for ver in release_versions if not ver.is_prerelease)

        if parse(petals.__version__) < latest:
            logger.info(
                f"A newer version {latest} is available. Please upgrade with: "
                f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
            )
    except Exception:
        # A failed version check must never prevent startup; log and carry on
        logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
    """Map a legacy ``bigscience/*-petals`` repo name to its original repo; pass anything else through."""
    if model_name_or_path is None:
        return None

    match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
    if match is not None:
        logger.info(
            f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
        )
        return match.group(1)

    return model_name_or_path
|
@ -1,25 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from huggingface_hub import delete_repo, list_models
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
|
|
||||||
parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
|
|
||||||
parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
|
|
||||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
|
||||||
parser.add_argument("--dry_run", action="store_true")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for model in list_models(author=args.author, full=True):
|
|
||||||
last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
|
|
||||||
|
|
||||||
if model.modelId.endswith("-main") or "/test-" not in model.modelId:
|
|
||||||
continue # remove only test models
|
|
||||||
|
|
||||||
if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
|
|
||||||
if args.dry_run:
|
|
||||||
print(f"{model.modelId} can be deleted")
|
|
||||||
else:
|
|
||||||
delete_repo(repo_id=model.modelId, token=args.use_auth_token)
|
|
Binary file not shown.
@ -1,43 +1,43 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import hivemind
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from test_utils import *
|
|
||||||
|
|
||||||
from petals.bloom.from_pretrained import load_pretrained_block
|
from petals import AutoDistributedConfig, RemoteSequential
|
||||||
from petals.client import DistributedBloomConfig
|
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
|
||||||
from petals.client.remote_sequential import RemoteTransformerBlock
|
from petals.server.from_pretrained import load_pretrained_block
|
||||||
from petals.data_structures import UID_DELIMITER
|
from test_utils import *
|
||||||
from petals.dht_utils import get_remote_module
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
|
||||||
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
||||||
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
remote_sequential = RemoteSequential(config)
|
||||||
|
|
||||||
for block_index in random.sample(range(config.n_layer), 3):
|
block_index = random.randint(0, config.num_hidden_layers - 1)
|
||||||
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
|
remote_block = remote_sequential[block_index]
|
||||||
assert isinstance(remote_block, RemoteTransformerBlock)
|
|
||||||
|
|
||||||
inputs = torch.randn(1, 8, config.hidden_size)
|
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
|
||||||
outputs_forward = remote_block(inputs)
|
outputs_forward = remote_block(inputs)
|
||||||
|
|
||||||
outputs_inference = []
|
outputs_inference = []
|
||||||
|
with torch.inference_mode():
|
||||||
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
|
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
|
||||||
for i in range(inputs.shape[1]):
|
# Test long inference (unmerged inference pools)
|
||||||
|
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
|
||||||
|
|
||||||
|
# Test short inference (merged inference pools)
|
||||||
|
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
|
||||||
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
|
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
|
||||||
|
|
||||||
# test that max length is respected
|
# test that max length is respected
|
||||||
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
|
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
|
||||||
sess.step(inputs[:, -1:, :])
|
sess.step(inputs[:, -1:, :])
|
||||||
assert "Maximum length exceeded" in repr(exc_info.value)
|
assert "Maximum length exceeded" in repr(exc_info.value)
|
||||||
|
outputs_inference = torch.cat(outputs_inference, dim=1)
|
||||||
|
|
||||||
outputs_inference = torch.cat(outputs_inference, dim=1)
|
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
|
||||||
|
(outputs_local,) = ref_block(inputs)
|
||||||
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
|
|
||||||
(outputs_local,) = ref_block(inputs)
|
|
||||||
|
|
||||||
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
|
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
|
||||||
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
|
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from petals.server.block_utils import resolve_block_dtype
|
||||||
|
from petals.server.from_pretrained import load_pretrained_block
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
from test_utils import MODEL_NAME
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
def test_block_dtype(torch_dtype):
    """Check that load_pretrained_block casts every parameter to the dtype resolve_block_dtype picks."""
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
    expected_dtype = resolve_block_dtype(config, torch_dtype)
    assert all(param.dtype == expected_dtype for param in block.parameters())
|
@ -1,108 +0,0 @@
|
|||||||
import bitsandbytes as bnb
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from bitsandbytes import functional as F
|
|
||||||
|
|
||||||
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt, get_inverse_transform_indices, undo_layout
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
|
|
||||||
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
|
|
||||||
)
|
|
||||||
def test_layout_exact_match():
|
|
||||||
x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
|
|
||||||
for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
|
|
||||||
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
|
|
||||||
tile_indices = get_inverse_transform_indices(transform, tile_size)
|
|
||||||
cxb = transform(x)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
restored_x = undo_layout(cxb, tile_indices)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
assert restored_x.is_contiguous()
|
|
||||||
assert torch.all(torch.eq(restored_x, x))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
|
|
||||||
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
|
|
||||||
)
|
|
||||||
def test_linear_exact_match():
|
|
||||||
linear = torch.nn.Linear(1024, 3072)
|
|
||||||
x = torch.randn(3, 1024, dtype=torch.half)
|
|
||||||
linear8bitlt = bnb.nn.Linear8bitLt(
|
|
||||||
linear.in_features,
|
|
||||||
linear.out_features,
|
|
||||||
linear.bias is not None,
|
|
||||||
has_fp16_weights=False,
|
|
||||||
threshold=6.0,
|
|
||||||
memory_efficient_backward=True,
|
|
||||||
)
|
|
||||||
linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
|
|
||||||
linear.weight.dtype
|
|
||||||
)
|
|
||||||
linear8bitlt.bias = linear.bias
|
|
||||||
linear8bitlt.cuda()
|
|
||||||
|
|
||||||
linear_custom = CustomLinear8bitLt(
|
|
||||||
linear.in_features,
|
|
||||||
linear.out_features,
|
|
||||||
linear.bias is not None,
|
|
||||||
has_fp16_weights=False,
|
|
||||||
threshold=6.0,
|
|
||||||
)
|
|
||||||
linear_custom.weight = bnb.nn.Int8Params(
|
|
||||||
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
|
||||||
).to(linear.weight.dtype)
|
|
||||||
linear_custom.bias = linear.bias
|
|
||||||
linear_custom.cuda()
|
|
||||||
|
|
||||||
x_ref = x.clone().cuda().requires_grad_(True)
|
|
||||||
x_ours = x.clone().cuda().requires_grad_(True)
|
|
||||||
fx_ref = linear8bitlt(x_ref).float()
|
|
||||||
grad_proj = torch.randn_like(fx_ref)
|
|
||||||
(fx_ref * grad_proj).mean().backward()
|
|
||||||
|
|
||||||
fx_ours = linear_custom(x_ours).float()
|
|
||||||
(fx_ours * grad_proj).mean().backward()
|
|
||||||
assert torch.equal(fx_ref, fx_ours)
|
|
||||||
assert torch.allclose(x_ref.grad, x_ours.grad)
|
|
||||||
assert not linear_custom.state.has_fp16_weights
|
|
||||||
assert linear_custom.state.CB is None
|
|
||||||
assert linear_custom.state.CxB is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
|
||||||
def test_linear_no_igemmlt():
|
|
||||||
linear = torch.nn.Linear(1024, 3072)
|
|
||||||
x = torch.randn(3, 1024, dtype=torch.half)
|
|
||||||
linear_custom = CustomLinear8bitLt(
|
|
||||||
linear.in_features,
|
|
||||||
linear.out_features,
|
|
||||||
linear.bias is not None,
|
|
||||||
has_fp16_weights=False,
|
|
||||||
threshold=6.0,
|
|
||||||
)
|
|
||||||
linear_custom.state.force_no_igemmlt = True
|
|
||||||
|
|
||||||
linear_custom.weight = bnb.nn.Int8Params(
|
|
||||||
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
|
||||||
).to(linear.weight.dtype)
|
|
||||||
linear_custom.bias = linear.bias
|
|
||||||
linear_custom.cuda()
|
|
||||||
linear.half().cuda()
|
|
||||||
|
|
||||||
x_ref = x.clone().cuda().requires_grad_(True)
|
|
||||||
x_ours = x.clone().cuda().requires_grad_(True)
|
|
||||||
fx_ref = linear(x_ref).float()
|
|
||||||
grad_proj = torch.randn_like(fx_ref)
|
|
||||||
(fx_ref * grad_proj).mean().backward()
|
|
||||||
|
|
||||||
fx_ours = linear_custom(x_ours).float()
|
|
||||||
(fx_ours * grad_proj).mean().backward()
|
|
||||||
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
|
|
||||||
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
|
|
||||||
assert not linear_custom.state.has_fp16_weights
|
|
||||||
assert linear_custom.state.CB is not None
|
|
||||||
assert linear_custom.state.CxB is None
|
|
@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from petals.utils.peft import check_peft_repository, load_peft
|
||||||
|
|
||||||
|
# Fixture repos on HF Hub — presumably one ships weights in an unsafe (pickle-based) format
# and one in a safe format; confirm against the repos themselves
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
TMP_CACHE_DIR = "tmp_cache/"
|
||||||
|
|
||||||
|
|
||||||
|
def clear_dir(path_to_dir):
    """Empty the directory by removing it (with all contents) and recreating it."""
    shutil.rmtree(path_to_dir)
    os.mkdir(path_to_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def dir_empty(path_to_dir):
    """Return True iff the directory at `path_to_dir` contains no entries."""
    return not os.listdir(path_to_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_check_peft():
    """check_peft_repository must reject the unsafe fixture repo and accept the safe one."""
    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_load_noncached(tmpdir):
    """Loading an unsafe repo must fail and leave the cache empty; a safe repo must populate it."""
    clear_dir(tmpdir)
    with pytest.raises(Exception):
        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)

    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)

    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_load_cached(tmpdir):
    """load_peft should succeed when the repo was already downloaded into the local cache."""
    clear_dir(tmpdir)
    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_load_layer_exists(tmpdir):
    """Loading the adapter slice for an existing block index should succeed."""
    clear_dir(tmpdir)

    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
    """Request an out-of-range block index.

    NOTE(review): no failure or result is asserted here — the test only checks the call does not
    crash; confirm whether load_peft is expected to raise for a missing block.
    """
    clear_dir(tmpdir)

    load_peft(
        SAFE_PEFT_REPO,
        block_idx=1337,
        cache_dir=tmpdir,
    )
|
@ -0,0 +1,39 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from petals import AutoDistributedConfig, RemoteSequential
|
||||||
|
from petals.server.handler import CACHE_TOKENS_AVAILABLE
|
||||||
|
from test_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
    """Check that servers report attention-cache usage via rpc_info and that it returns to baseline.

    Opens two nested inference sessions over overlapping block ranges and verifies that
    CACHE_TOKENS_AVAILABLE drops by max_length * num_blocks while each session is open.
    """
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

    info_before = blocks1.sequence_manager.rpc_info

    with blocks1.inference_session(max_length=max_length) as sess:
        sess.step(torch.randn(1, 1, config.hidden_size))
        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
        info_inside = blocks1.sequence_manager.rpc_info

        with blocks2.inference_session(max_length=max_length2) as sess2:
            sess2.step(torch.randn(1, 1, config.hidden_size))
            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
            info_inside2 = blocks2.sequence_manager.rpc_info

    # presumably gives the server time to release the sessions' cache — confirm
    time.sleep(0.1)
    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
    info_after = blocks1.sequence_manager.rpc_info

    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
|
@ -0,0 +1,49 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from tensor_parallel import TensorParallel
|
||||||
|
from tensor_parallel.slicing_configs import get_bloom_config
|
||||||
|
|
||||||
|
from petals.server.from_pretrained import load_pretrained_block
|
||||||
|
from test_utils import MODEL_NAME
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    """Compare a tensor-parallel block against the single-device reference: forward (with and
    without a cached prefix) and backward gradients must match within tolerance."""
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    if model_config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    block_index = random.randint(0, 10)
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5

    # Paired inputs: *1 feeds the reference block, *2 the tensor-parallel copy
    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    # Reference: run prefix then inputs through the unmodified block with the attention cache
    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
    y_ref.backward(grad_proj)

    # Same computation through the tensor-parallel wrapper
    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
    y_ours.backward(grad_proj)

    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
    assert torch.allclose(y_ours, y_ref, atol=1e-5)
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
|
Loading…
Reference in New Issue