Merge branch 'main' into repetition-penalty
commit dd677d9e76
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModel
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    step_times = []
    for step in range(args.warmup_steps + args.n_steps):
        start_time = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        h = model(input_ids)
        # We don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            step_times.append(perf_counter() - start_time)
            speed = input_ids.numel() / np.mean(step_times)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
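Note: the benchmark above aggregates results through a single one-way multiprocessing pipe whose send handle is shared by every worker. A minimal standalone sketch of that pattern (the worker function and values here are hypothetical, not part of the diff):

import multiprocessing as mp

import numpy as np


def _worker(idx: int, pipe_send) -> None:
    # Each worker pushes exactly one scalar result into the shared send end
    pipe_send.send(float(idx) * 2.0)


if __name__ == "__main__":
    pipe_recv, pipe_send = mp.Pipe(duplex=False)  # one-way pipe: many senders, one receiver
    procs = [mp.Process(target=_worker, args=(i, pipe_send)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Receive one result per worker and average them, as the benchmark does
    print(np.mean([pipe_recv.recv() for _ in range(4)]))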
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    step_times = []
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                step_times.append(perf_counter() - start_time)
                speed = 1 / np.mean(step_times)
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
@@ -0,0 +1,107 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
    parser.add_argument("--task", type=str, default="cls", help="Training task type")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    assert args.task in ["cls", "causal_lm"]

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
    logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")


def benchmark_training(process_idx, args, result_pipe):
    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
            num_labels=2,
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
        )
    model = model.to(args.device)
    opt = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        start_time = perf_counter()
        outputs = model(input_ids, labels=labels)
        if step >= args.warmup_steps:
            fwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Backward")
        start_time = perf_counter()
        outputs.loss.backward()
        if step >= args.warmup_steps:
            bwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        opt.step()
        opt.zero_grad()

        if step >= args.warmup_steps:
            fwd_speed = input_ids.numel() / np.mean(fwd_times)
            bwd_speed = input_ids.numel() / np.mean(bwd_times)
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    result_pipe.send((fwd_speed, bwd_speed))


if __name__ == "__main__":
    main()
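Note: each training worker sends a (fwd_speed, bwd_speed) pair, and the parent averages them with axis=0 so the two directions stay separate. A small illustration with made-up numbers:

import numpy as np

# Hypothetical per-process results: (forward tokens/s, backward tokens/s)
results = [(1200.0, 450.0), (1100.0, 500.0)]
fwd_speed, bwd_speed = np.mean(results, axis=0)  # column-wise mean -> (1150.0, 475.0)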
@@ -1,6 +1,29 @@
import os

os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")

import hivemind
import transformers
from packaging import version

from petals.client import *
from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

__version__ = "1.0alpha1"
__version__ = "2.0.1.post2"


if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"


def _override_bfloat16_mode_default():
    if os.getenv("USE_LEGACY_BFLOAT16") is None:
        hivemind.compression.base.USE_LEGACY_BFLOAT16 = False


_initialize_logs()
_override_bfloat16_mode_default()
@@ -1,59 +0,0 @@
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import os
from typing import Optional, Tuple

import torch.nn.quantized.dynamic.modules.linear
import transformers
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"


class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask
@@ -1,125 +0,0 @@
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
from __future__ import annotations

import itertools
import time
from typing import Optional, OrderedDict, Union

import torch
from hivemind.utils.logging import get_logger
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import get_file_from_repo

from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import get_block_size
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

logger = get_logger(__file__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"


def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""

    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    block = WrappedBloomBlock(config)
    state_dict = _load_state_dict(
        converted_model_name_or_path,
        block_index,
        config,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    if torch_dtype == "auto":
        with torch.no_grad():
            for name, param in block.named_parameters():
                assert name in state_dict, f"{name} not in state dict"
                param.data = param.data.to(state_dict[name].dtype)
    else:
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        block = block.to(dtype=torch_dtype)

    report = block.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block


def _load_state_dict(
    pretrained_model_name_or_path: str,
    block_index: int,
    config: BloomConfig,
    *,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    min_backoff: float = 5,
) -> OrderedDict[str, torch.Tensor]:
    revision = BLOCK_BRANCH_PREFIX + str(block_index)

    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            archive_file = get_file_from_repo(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if archive_file is not None:
                return torch.load(archive_file, map_location="cpu")
    except Exception:
        logger.debug(
            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
        )

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    for attempt_no in itertools.count():
        try:
            with allow_cache_writes(cache_dir):
                block_size = get_block_size(config, "disk")
                free_disk_space_for(
                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
                )

                archive_file = get_file_from_repo(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                return torch.load(archive_file, map_location="cpu")
        except Exception as e:
            delay = min_backoff * (2**attempt_no)
            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)


DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
@@ -1,72 +0,0 @@
"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import BloomConfig

logger = get_logger(__file__)


class LMHead(nn.Module):
    """
    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
    """

    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    @property
    def in_features(self) -> int:
        return self.word_embeddings.num_embeddings

    @property
    def out_features(self) -> int:
        return self.word_embeddings.embedding_dim

    @property
    def weight(self):
        return self.word_embeddings.weight

    @property
    def bias(self):
        return None

    def forward(self, hidden_states):
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunk_size: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)

        for i in range(0, num_embeddings, self.chunk_size):
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output
@@ -1,20 +0,0 @@
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
@@ -1,92 +0,0 @@
import argparse
import os

import psutil
import torch.backends.quantized
import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger
from huggingface_hub import Repository
from tqdm.auto import tqdm
from transformers.models.bloom.modeling_bloom import BloomModel

from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig

logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
@@ -1,79 +0,0 @@
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
  echo " -m: model name"
  echo " -i: initial peer"
  echo " -d: device" >&2
  echo " -p: server identity path" >&2
  echo " -b: block_ids" >&2
  echo " -a: host maddrs" >&2
  echo " -t: whether to run local tests" >&2
  exit 1
}

if [ ! $# -ge 8 ]; then
  instructions
fi

while getopts ":m:i:d:p:b:a:t:" option; do
  case $option in
    m) MODEL_NAME=${OPTARG}
       ;;
    i) INITIAL_PEER=${OPTARG}
       ;;
    d) DEVICE=${OPTARG}
       ;;
    p) SERVER_ID_PATH=${OPTARG}
       ;;
    b) BLOCK_IDS=${OPTARG}
       ;;
    a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
       ;;
    t) RUN_LOCAL_TESTS=true
       ;;
    \?) instructions
       ;;
  esac
done


echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"


###########################
# Install or activate env #
###########################

# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple -r .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi

##############
# Run server #
##############

python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
@@ -1,51 +0,0 @@
import argparse

import torch
from hivemind.utils.logging import get_logger
from tqdm.auto import trange
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

from petals.bloom.block import BloomBlock

logger = get_logger(__file__)

logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional Info when using cuda
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info(f"Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config).to(args.device)

    cache = None

    for i in trange(args.num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
        with torch.no_grad():
            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)
@@ -1,5 +0,0 @@
device=cpu
block_ids=2:3
id_path=./server.id
maddr=/ip4/127.0.0.1/tcp/30000
#
@@ -1,6 +0,0 @@
name=bloom-peer-0.bloom.net
device=cpu
block_ids=1:3
id_path=./server.id
maddr=/ip4/0.0.0.0/tcp/30000
#
@@ -0,0 +1,106 @@
"""
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py

This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.

This may be eventually merged to the hivemind upstream.
"""

import argparse
import time
from secrets import token_hex

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

from petals.server.reachability import ReachabilityProtocol

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


async def report_status(dht: DHT, node: DHTNode):
    logger.info(
        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
        f"are in the local routing table "
    )
    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
    logger.debug(f"Local storage contents: {node.protocol.storage}")

    # Contact peers and keep the routing table healthy (remove stale PeerIDs)
    await node.get(f"heartbeat_{token_hex(16)}", latest=True)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--initial_peers",
        nargs="*",
        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
    )
    parser.add_argument(
        "--host_maddrs",
        nargs="*",
        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
        help="Multiaddrs to listen for external connections from other DHT instances. "
        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
    )
    parser.add_argument(
        "--announce_maddrs",
        nargs="*",
        help="Visible multiaddrs the host announces for external connections from other DHT instances",
    )
    parser.add_argument(
        "--use_ipfs",
        action="store_true",
        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
        "part of the multiaddrs for the initial_peers "
        "(no need to specify a particular IPv4/IPv6 host and port)",
    )
    parser.add_argument(
        "--identity_path",
        help="Path to a private key file. If defined, makes the peer ID deterministic. "
        "If the file does not exist, writes a new private key to this file.",
    )
    parser.add_argument(
        "--no_relay",
        action="store_false",
        dest="use_relay",
        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
    )
    parser.add_argument(
        "--use_auto_relay",
        action="store_true",
        help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
    )
    parser.add_argument(
        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
    )

    args = parser.parse_args()

    dht = DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
        identity_path=args.identity_path,
        use_relay=args.use_relay,
        use_auto_relay=args.use_auto_relay,
    )
    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)

    while True:
        dht.run_coroutine(report_status, return_future=False)
        time.sleep(args.refresh_period)


if __name__ == "__main__":
    main()
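Note: the bootstrap script above prints the node's visible multiaddrs; other peers join the swarm by passing one of them as initial_peers. A minimal hedged sketch of the client side (the multiaddr below is a placeholder, not a real peer):

from hivemind.dht import DHT

# Replace with a multiaddr printed by the bootstrap node launched by the script above
bootstrap_maddr = "/ip4/203.0.113.1/tcp/31337/p2p/QmExamplePeerID"

dht = DHT(initial_peers=[bootstrap_maddr], client_mode=True, start=True)
print(dht.get_visible_maddrs())
dht.shutdown()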
@@ -1,109 +0,0 @@
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-n] [-c]" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 4 ]; then
  instructions
fi

while getopts ":n:c:t:" option; do
  case $option in
    n) NUM_SERVERS=${OPTARG}
       ;;
    c) CONFIG_PATH=${OPTARG}
       ;;
    \?) instructions
       ;;
  esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple -r .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/127.0.0.1/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############

  while read line
  do
    if echo $line | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  echo "=== Server #${SERVER_ID} ==="
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo ""

  ##############
  # Run server #
  ##############

  tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done

#####################
# Kill initial peer #
#####################

sleep 10
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
rm tmp.out
@@ -1,110 +0,0 @@
#!/usr/bin/env bash

SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-u] [-n] [-c]" >&2
  echo " -u: username" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 6 ]; then
  instructions
fi

while getopts ":u:n:c:" option; do
  case $option in
    u) USERNAME=${OPTARG}
       ;;
    n) NUM_SERVERS=${OPTARG}
       ;;
    c) CONFIG_PATH=${OPTARG}
       ;;
    \?) instructions
       ;;
  esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple -r .
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &

sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
  [name]=""
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/0.0.0.0/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############

  while read line
  do
    if echo $line | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  SERVER_NAME="${USERNAME}@${cfg[name]}"
  echo "=== Server #${SERVER_ID} ==="
  echo "Server name ${SERVER_NAME}"
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo "================="

  ##############
  # Run server #
  ##############

  ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
@@ -1,10 +1,4 @@
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import (
    DistributedBloomConfig,
    DistributedBloomForCausalLM,
    DistributedBloomForSequenceClassification,
    DistributedBloomModel,
)
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
@@ -0,0 +1,94 @@
import contextlib
import json
import os
import re
import tempfile
import threading
from typing import List, Optional, Tuple, Union

import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils

from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)


class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(
                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
            )

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )


_shard_config = threading.local()
_shard_config.ignored_keys = None


@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    try:
        prev_patterns = _shard_config.ignored_keys
        _shard_config.ignored_keys = patterns
        yield
    finally:
        _shard_config.ignored_keys = prev_patterns


def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""

    should_ignore_keys = _shard_config.ignored_keys is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)


original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
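Note: the ignore_keys() context manager above is what lets the mixin avoid downloading shards whose parameters will not be used on the client. A hedged sketch of the intended usage pattern (the pattern list and model class here are hypothetical, for illustration only):

# Any checkpoint shard whose parameters all match one of the patterns is dropped
# from the weight_map index and is therefore never downloaded.
with ignore_keys([r"^lm_head\."]):  # hypothetical pattern
    model = SomeShardedModel.from_pretrained("org/sharded-checkpoint")  # hypothetical model class and repo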
@@ -0,0 +1,84 @@
import dataclasses
import platform
from typing import Optional, Union

import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import PretrainedConfig

logger = get_logger(__name__)


@dataclasses.dataclass
class LMHeadConfig:
    # These settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    use_chunked_forward: Union[str, bool] = "auto"
    chunked_forward_step: int = 16384


class LMHead(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        if not config.tie_word_embeddings:
            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
            self.weight.requires_grad = False
        else:
            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
        self.bias = None
        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
        self.out_features = config.vocab_size

        self.use_chunked_forward = config.use_chunked_forward
        if self.use_chunked_forward == "auto":
            if platform.machine() == "x86_64":
                # Import of cpufeature may crash on non-x86_64 machines
                from cpufeature import CPUFeature

                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
                # Otherwise, it's ~8x slower.
                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
            else:
                self.use_chunked_forward = True
        self.chunked_forward_step = config.chunked_forward_step
        self._bf16_warning_shown = False

    def forward(self, hidden_states):
        if (
            self.weight.dtype in [torch.float16, torch.bfloat16]
            and self.weight.device.type == "cpu"
            and self.use_chunked_forward
        ):
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(self.weight.dtype)
            lm_logits = F.linear(hidden_states, self.weight)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                logger.warning(
                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                    "Consider loading the model with torch_dtype='float32'"
                )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], self.out_features)

        for i in range(0, self.out_features, self.chunked_forward_step):
            chunk = self.weight[i : i + self.chunked_forward_step].float()
            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
        return output
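Note: chunked_forward() above is a blocked matrix multiply over the vocabulary dimension. A standalone sketch showing that chunking the weight rows reproduces the full F.linear result (shapes are arbitrary example values, not real model dimensions):

import torch
import torch.nn.functional as F

hidden = torch.randn(2, 5, 64)   # [batch, seq, hidden_size]
weight = torch.randn(1000, 64)   # [vocab_size, hidden_size]
step = 256                       # number of vocabulary rows per matmul

out = torch.empty(2, 5, 1000)
for i in range(0, weight.shape[0], step):
    # Each chunk produces logits for a slice of the vocabulary
    out[..., i : i + step] = F.linear(hidden, weight[i : i + step])

assert torch.allclose(out, F.linear(hidden, weight), atol=1e-5)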
@@ -0,0 +1,84 @@
import dataclasses
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig

from petals.utils.misc import DUMMY

logger = get_logger(__name__)


@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]


class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)


_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
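Note: get_prompt() above turns one flat embedding per prefix token into a per-layer prompt tensor via view and permute. A small shape-only illustration of that reshaping (toy sizes, not real model dimensions):

import torch

batch_size, pre_seq_len, num_hidden_layers, hidden_size = 2, 4, 3, 8
flat = torch.randn(batch_size, pre_seq_len, num_hidden_layers * hidden_size)

intermediate = flat.view(batch_size, pre_seq_len, num_hidden_layers, hidden_size)
intermediate = intermediate.permute([2, 0, 1, 3])
print(intermediate.shape)  # torch.Size([3, 2, 4, 8]) -> [layers, batch, pre_seq_len, hidden]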
@ -1,264 +0,0 @@
import os
from contextlib import contextmanager
from typing import List, Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import (
    BloomConfig,
    BloomForCausalLM,
    BloomForSequenceClassification,
    BloomModel,
    BloomPreTrainedModel,
)

from petals.bloom.modeling_utils import LMHead
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.utils.misc import DUMMY

logger = get_logger(__file__)


class DistributedBloomConfig(BloomConfig):
    """
    A bloom config that contains information about DHT peers.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
    daemon_startup_timeout: int = 30
    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
    request_timeout: int = 30  # a number of seconds for waiting result from each node


original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter


class _LowCPUMemoryMixin:
    @classmethod
    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    )


class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
        r"^(intermediate_)?prompt_embeddings\.weight$",
    ]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.n_layer = n_layer

        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(
                initial_peers=config.initial_peers,
                client_mode=True,
                num_workers=n_layer,
                startup_timeout=config.daemon_startup_timeout,
                start=True,
            )
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)

        # Forbid accumulate grads for embeddings and layernorm
        self.set_requires_grad(False)

        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
                    logger.info(
                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
                        "to increase ptune quality"
                    )
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
    )

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        assert isinstance(new_embeddings, nn.Embedding)
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            self.lm_head.bias[...] = new_lm_head.bias


class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
    _keys_to_ignore_on_load_missing = (
        BloomForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
    )

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
@ -1,6 +1,18 @@
import torch

PUBLIC_INITIAL_PEERS = [
    "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    # IPv4 DNS addresses
    "/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    # IPv6 DNS addresses
    "/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    # Reserved IPs
    "/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]

# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "https://health.petals.dev"

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
@ -0,0 +1,2 @@
from petals.models.bloom import *
from petals.models.llama import *
@ -0,0 +1,15 @@
from petals.models.bloom.block import WrappedBloomBlock
from petals.models.bloom.config import DistributedBloomConfig
from petals.models.bloom.model import (
    DistributedBloomForCausalLM,
    DistributedBloomForSequenceClassification,
    DistributedBloomModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedBloomConfig,
    model=DistributedBloomModel,
    model_for_causal_lm=DistributedBloomForCausalLM,
    model_for_sequence_classification=DistributedBloomForSequenceClassification,
)
@ -0,0 +1,32 @@
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None, "Non-causal attention masks are not supported yet"
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
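Note: a shape-level sketch of the cache layout this wrapper assumes (sizes are illustrative). Bloom keeps past keys as (batch * num_heads, head_dim, past_length) and past values as (batch * num_heads, past_length, head_dim), which is why the past length is read from layer_past[0].shape[-1] above:

import torch

batch_size, num_heads, head_dim, past_length = 2, 8, 64, 5
past_key = torch.zeros(batch_size * num_heads, head_dim, past_length)
past_value = torch.zeros(batch_size * num_heads, past_length, head_dim)

layer_past = (past_key, past_value)
assert layer_past[0].shape[-1] == past_length  # matches the computation in WrappedBloomBlock.forward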
@ -0,0 +1,34 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention

from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock

logger = get_logger(__name__)


class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedBloomBlock
    attn_class = BloomAttention
    block_prefix = "h"

    num_key_value_groups = 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")

        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            # We need "-petals" for backward compatibility with Petals < 1.2.0
            dht_prefix = str(model_name_or_path) + "-petals"
            logger.info(f"Using DHT prefix: {dht_prefix}")
        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
@ -0,0 +1,126 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig

logger = get_logger(__name__)


class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^h\."]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
@ -0,0 +1,15 @@
from petals.models.llama.block import WrappedLlamaBlock
from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import (
    DistributedLlamaForCausalLM,
    DistributedLlamaForSequenceClassification,
    DistributedLlamaModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedLlamaConfig,
    model=DistributedLlamaModel,
    model_for_causal_lm=DistributedLlamaForCausalLM,
    model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)
@ -0,0 +1,91 @@
"""
LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel


class WrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = LlamaModel._prepare_decoder_attention_mask(
            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
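Note: a shape-level sketch of the reorder helpers above (sizes are illustrative). Petals passes the attention cache around in Bloom-style layout, while LlamaAttention expects (batch, num_heads, seq, head_dim), so the key tensor is reshaped and its last two axes swapped:

import torch

batch_size, num_kv_heads, head_dim, past_length = 2, 8, 64, 5

# Bloom-style cache layout: keys (batch * heads, head_dim, seq), values (batch * heads, seq, head_dim)
key_bloom = torch.randn(batch_size * num_kv_heads, head_dim, past_length)
value_bloom = torch.randn(batch_size * num_kv_heads, past_length, head_dim)

# LLaMA-style layout expected by LlamaAttention, mirroring _reorder_cache_from_bloom_to_llama
key_llama = key_bloom.permute(0, 2, 1).view(batch_size, num_kv_heads, past_length, head_dim)
value_llama = value_bloom.view(batch_size, num_kv_heads, past_length, head_dim)
assert key_llama.shape == value_llama.shape == (batch_size, num_kv_heads, past_length, head_dim)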
@ -0,0 +1,45 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock

logger = get_logger(__name__)


class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedLlamaBlock
    attn_class = LlamaAttention
    block_prefix = "model.layers"

    @property
    def num_key_value_groups(self):
        return self.num_attention_heads // self.num_key_value_heads

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        logger.info(
            "Make sure you follow the LLaMA's terms of use: "
            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
        )

        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            if not dht_prefix.endswith("-hf"):
                dht_prefix += "-hf"
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
        return result
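Note: the DHT prefix rule above, traced on an example repository name (the name is only an illustration):

model_name_or_path = "meta-llama/Llama-2-70b-chat-hf"
dht_prefix = str(model_name_or_path).split("/")[-1]  # keep only the repo name: "Llama-2-70b-chat-hf"
if not dht_prefix.endswith("-hf"):
    dht_prefix += "-hf"
assert dht_prefix == "Llama-2-70b-chat-hf"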
@ -0,0 +1,151 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig

logger = get_logger(__name__)


class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    """LlamaModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        self.layers = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.layers(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm


class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        LlamaPreTrainedModel.__init__(self, config)
        self.model = DistributedLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
@ -0,0 +1,211 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations

from typing import AsyncIterator, Optional, Sequence, Tuple, Union

import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.nested import nested_flatten

from petals.data_structures import InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy

# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1


async def run_rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states


async def run_rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape


async def iterate_rpc_inference(
    requested_uids: Sequence[ExpertUID],
    requested_backends: Sequence[TransformerBackend],
    active_adapter: Optional[str],
    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
    cache_handles: Sequence[Sequence[Handle]],
    *,
    max_length: int,
    prioritizer: TaskPrioritizerBase,
    points: int,
    quant_type: QuantType,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
    assert len(cache_handles) == len(requested_backends)

    prefix_length = 0
    point_per_piece = points / max_length if max_length > 0 else 0.0

    async for request, step_metadata in input_iterator:
        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
        batch_size, length_increment, _ = hidden_states.shape

        # Cast inputs to backend dtype
        hidden_states = hidden_states.to(requested_backends[0].dtype)
        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

        # parse deep prompts (optional argument)
        has_prompts = prompts is not None and not is_dummy(prompts)
        if not has_prompts:
            prompts = [None] * len(requested_backends)
        else:
            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]

        if not (len(requested_backends) == len(prompts)):
            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

        if prefix_length + length_increment > max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                f" exceeds pre-allocated maximum {max_length}"
            )

        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
        can_merge_pools = batch_size * length_increment <= merge_max_tokens
        priority = prioritizer.prioritize(
            hidden_states,
            hypo_ids,
            points=point_per_piece,
            requested_uids=requested_uids,
            type="short_inference" if can_merge_pools else "inference",
        )

        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
        if hidden_states.numel() > 0:
            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
            if can_merge_pools:
                inference_infos = tuple(
                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                    for uid, handles in zip(requested_uids, cache_handles)
                )
                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                )
            else:
                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                    (hidden_states,) = await backend.inference_pool.submit_task(
                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
                    )

        # serialize and send last layer outputs
        output_tensors = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]
        can_push = not has_prompts
        yield output_tensors, can_push

        # prepare for next step
        prefix_length += length_increment
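Note: a small sketch of how the two constants above select a pool for an inference step (the numbers are illustrative):

quant_type_is_nf4 = False  # illustrative; in practice this comes from the server's block conversion settings
merge_max_tokens = 1 if quant_type_is_nf4 else 128  # MAX_NF4_SHORT_INFERENCE_TOKENS vs. MAX_SHORT_INFERENCE_TOKENS
batch_size, length_increment = 1, 1  # a single autoregressive decoding step
can_merge_pools = batch_size * length_increment <= merge_max_tokens
assert can_merge_pools  # short steps go to the merged inference pool; long prefills use per-block pools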
@ -0,0 +1,177 @@
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
import json
import time
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo

from petals.constants import DTYPE_MAP
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth

logger = get_logger(__name__)


def load_pretrained_block(
    model_name: str,
    block_index: int,
    *,
    config: Optional[PretrainedConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    if config is None:
        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    with init_empty_weights():
        block = config.block_class(config)

    block_prefix = f"{config.block_prefix}.{block_index}."
    state_dict = _load_state_dict_from_repo(
        model_name,
        block_prefix,
        revision=revision,
        token=token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {model_name} block {block_index}, {report}")
    return block


StateDict = Dict[str, torch.Tensor]


def _load_state_dict_from_repo(
    model_name: str,
    block_prefix: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
) -> StateDict:
    if always_needs_auth(model_name) and token is None:
        token = True

    index_file = get_file_from_repo(
        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
    )
    if index_file is not None:  # Sharded model
        with open(index_file) as f:
            index = json.load(f)
        filenames = {
            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
        }
        if not filenames:
            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
    else:  # Non-sharded model
        filenames = {"pytorch_model.bin"}
    logger.debug(f"Loading {block_prefix}* from {filenames}")

    state_dict = {}
    for filename in filenames:
        shard_state_dict = _load_state_dict_from_file(
            model_name,
            filename,
            revision=revision,
            token=token,
            cache_dir=cache_dir,
            max_disk_space=max_disk_space,
        )
        shard_state_dict = {
            param_name[len(block_prefix) :]: param
            for param_name, param in shard_state_dict.items()
            if param_name.startswith(block_prefix)
        }  # Remove unused parameters from memory
        state_dict.update(shard_state_dict)
    return state_dict


def _load_state_dict_from_file(
    model_name: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
) -> StateDict:
    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            path = get_file_from_repo(
                model_name,
                filename,
                revision=revision,
                use_auth_token=token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if path is not None:
                return torch.load(path, map_location="cpu")
    except Exception:
        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    while True:
        try:
            with allow_cache_writes(cache_dir):
                url = hf_hub_url(model_name, filename, revision=revision)
                file_size = get_hf_file_metadata(url, token=token).size
                if file_size is not None:
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")

                path = get_file_from_repo(
                    model_name,
                    filename,
                    revision=revision,
                    use_auth_token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                if path is None:
                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
                return torch.load(path, map_location="cpu")
        except Exception as e:
            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)
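Note: a usage sketch for load_pretrained_block (the repository name is an example, and the import path assumes this file lives at petals/server/from_pretrained.py; running it downloads the corresponding weight shard):

import torch
from petals.server.from_pretrained import load_pretrained_block

# Fetch a single transformer block on CPU in bfloat16 (example model and block index)
block = load_pretrained_block("bigscience/bloom-560m", block_index=3, torch_dtype=torch.bfloat16)
print(sum(p.numel() for p in block.parameters()))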
@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from hivemind.dht import DHT, DHTNode
|
||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
|
||||
from hivemind.proto import dht_pb2
|
||||
from hivemind.utils import get_logger
|
||||
|
||||
from petals.constants import REACHABILITY_API_URL
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
|
||||
"""verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
|
||||
for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
|
||||
try:
|
||||
r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
|
||||
r.raise_for_status()
|
||||
response = r.json()
|
||||
|
||||
if response["success"]:
|
||||
logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon")
|
||||
return
|
||||
|
||||
if attempt_no == 0:
|
||||
# Usually, libp2p manages to set up relays before we finish loading blocks.
|
||||
# In other cases, we may need to wait for up to `wait_time` seconds before it's done.
|
||||
logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
|
||||
time.sleep(retry_delay)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}")
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Server has not become reachable from the Internet:\n\n"
|
||||
f"{response['message']}\n\n"
|
||||
f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
|
||||
f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
|
||||
f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
|
||||
f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
|
||||
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
|
||||
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
|
||||
)
|
||||
|
||||
|
||||
def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
|
||||
"""test if your peer is accessible by others in the swarm with the specified network options in **kwargs"""
|
||||
|
||||
async def _check_direct_reachability():
|
||||
target_dht = await DHTNode.create(client_mode=True, **kwargs)
|
||||
try:
|
||||
protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
|
||||
async with protocol.serve(target_dht.protocol.p2p):
|
||||
successes = requests = 0
|
||||
for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
|
||||
probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
|
||||
if probe_available is None:
|
||||
continue # remote peer failed to check probe
|
||||
successes += probe_available
|
||||
requests += 1
|
||||
if requests >= max_peers:
|
||||
break
|
||||
|
||||
logger.debug(f"Direct reachability: {successes}/{requests}")
|
||||
return (successes / requests) >= threshold if requests > 0 else None
|
||||
finally:
|
||||
await target_dht.shutdown()
|
||||
|
||||
return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
|
||||
|
||||
|
||||
STRIPPED_PROBE_ARGS = dict(
|
||||
dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
|
||||
)
|
||||
|
||||
|
||||
class ReachabilityProtocol(ServicerBase):
|
||||
"""Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""
|
||||
|
||||
def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
|
||||
self.probe = probe
|
||||
self.wait_timeout = wait_timeout
|
||||
self._event_loop = self._stop = None
|
||||
|
||||
async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
|
||||
"""Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
|
||||
try:
|
||||
request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
|
||||
timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
|
||||
response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
|
||||
logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
|
||||
return response.available
|
||||
except Exception as e:
|
||||
logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
|
||||
return None
|
||||
|
||||
async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
|
||||
"""Help another peer to check its reachability"""
|
||||
response = dht_pb2.PingResponse(available=True)
|
||||
check_peer = PeerID(request.peer.node_id)
|
||||
if check_peer != context.local_id: # remote peer wants us to check someone other than ourselves
|
||||
response.available = await self.call_check(check_peer, check_peer=check_peer) is True
|
||||
logger.info(
|
||||
f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
|
||||
f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
|
||||
)
|
||||
return response
|
||||
|
||||
@asynccontextmanager
|
||||
async def serve(self, p2p: P2P):
|
||||
try:
|
||||
await self.add_p2p_handlers(p2p)
|
||||
yield self
|
||||
finally:
|
||||
await self.remove_p2p_handlers(p2p)
|
||||
|
||||
@classmethod
|
||||
def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
|
||||
protocol = cls(**kwargs)
|
||||
ready = Future()
|
||||
|
||||
async def _serve_with_probe():
|
||||
try:
|
||||
common_p2p = await dht.replicate_p2p()
|
||||
protocol._event_loop = asyncio.get_event_loop()
|
||||
protocol._stop = asyncio.Event()
|
||||
|
||||
initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
|
||||
for info in await common_p2p.list_peers():
|
||||
initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
|
||||
protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
|
||||
|
||||
ready.set_result(True)
|
||||
logger.info("Reachability service started")
|
||||
|
||||
async with protocol.serve(common_p2p):
|
||||
await protocol._stop.wait()
|
||||
except Exception as e:
|
||||
logger.debug("Reachability service failed:", exc_info=True)
|
||||
|
||||
if not ready.done():
|
||||
ready.set_exception(e)
|
||||
finally:
|
||||
if protocol is not None and protocol.probe is not None:
|
||||
await protocol.probe.shutdown()
|
||||
logger.debug("Reachability service shut down")
|
||||
|
||||
threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
|
||||
if await_ready:
|
||||
ready.result() # Propagates startup exceptions, if any
|
||||
return protocol
|
||||
|
||||
def shutdown(self):
|
||||
if self._event_loop is not None and self._stop is not None:
|
||||
self._event_loop.call_soon_threadsafe(self._stop.set)
|
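A minimal usage sketch for the reachability protocol above. The module path petals.server.reachability and the single-node DHT setup are assumptions for illustration only, not part of this diff:

import hivemind

from petals.server.reachability import ReachabilityProtocol

dht = hivemind.DHT(start=True)  # a running hivemind DHT; real servers also pass initial_peers
protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True, wait_timeout=5.0)
# ... while running, the probe answers rpc_check requests from other peers in the swarm ...
protocol.shutdown()
dht.shutdown()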
@ -0,0 +1,6 @@
|
||||
from petals.utils.auto_config import (
|
||||
AutoDistributedConfig,
|
||||
AutoDistributedModel,
|
||||
AutoDistributedModelForCausalLM,
|
||||
AutoDistributedModelForSequenceClassification,
|
||||
)
|
@ -0,0 +1,65 @@
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
from petals.utils.hf_auth import always_needs_auth
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ModelClasses:
|
||||
config: Type[PretrainedConfig]
|
||||
model: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
|
||||
|
||||
|
||||
_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
|
||||
|
||||
|
||||
def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
|
||||
assert issubclass(config, PretrainedConfig)
|
||||
assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
|
||||
|
||||
_CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
|
||||
|
||||
|
||||
class _AutoDistributedBase:
|
||||
_mapping_field = None # Should be defined in child classes
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
|
||||
if (
|
||||
always_needs_auth(model_name_or_path)
|
||||
and kwargs.get("token") is None
|
||||
and kwargs.get("use_auth_token") is None
|
||||
):
|
||||
kwargs["use_auth_token"] = True
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
|
||||
if config.model_type not in _CLASS_MAPPING:
|
||||
raise ValueError(f"Petals does not support model type {config.model_type}")
|
||||
|
||||
proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
|
||||
if proper_cls is None:
|
||||
raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
|
||||
|
||||
return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
|
||||
|
||||
|
||||
class AutoDistributedConfig(_AutoDistributedBase):
|
||||
_mapping_field = "config"
|
||||
|
||||
|
||||
class AutoDistributedModel(_AutoDistributedBase):
|
||||
_mapping_field = "model"
|
||||
|
||||
|
||||
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
|
||||
_mapping_field = "model_for_causal_lm"
|
||||
|
||||
|
||||
class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
|
||||
_mapping_field = "model_for_sequence_classification"
|
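A short sketch of how this registry is meant to be used. MyDistributedConfig and the "my_model" model type are made-up names for illustration; real mappings are registered by the petals.models subpackages:

from transformers import PretrainedConfig

from petals.utils.auto_config import AutoDistributedConfig, register_model_classes


class MyDistributedConfig(PretrainedConfig):  # stand-in for a real distributed config class
    model_type = "my_model"


register_model_classes(config=MyDistributedConfig)
# AutoDistributedConfig.from_pretrained(checkpoint) now dispatches to MyDistributedConfig
# whenever AutoConfig resolves a checkpoint whose model_type is "my_model".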
@ -1,39 +0,0 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
|
||||
|
||||
|
||||
def replace_8bit_linear(model, threshold=6.0):
|
||||
"""
|
||||
A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
|
||||
library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
|
||||
8-bit Matrix Multiplication for Transformers at Scale`. Make sure that a `bitsandbytes` build compiled for the CUDA
|
||||
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
|
||||
bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 11.6 = 116).
|
||||
The function is applied recursively and replaces all `torch.nn.Linear` modules except for `lm_head` and `score`, which should
|
||||
be kept as a `torch.nn.Linear` module.
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
threshold (`float`, *optional*):
|
||||
`int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
|
||||
`6.0` as described by the paper.
|
||||
"""
|
||||
for n, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_8bit_linear(module, threshold)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
|
||||
model._modules[n] = CustomLinear8bitLt(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=threshold,
|
||||
)
|
||||
model._modules[n].weight = bnb.nn.Int8Params(
|
||||
module.weight.data, requires_grad=False, has_fp16_weights=False
|
||||
).to(module.weight.dtype)
|
||||
model._modules[n].bias = module.bias
|
||||
return model
|
@ -0,0 +1,156 @@
|
||||
"""
|
||||
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
|
||||
"""
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import tensor_parallel as tp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
from tensor_parallel.slicing_configs import get_bloom_config
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
NONE = 0
|
||||
INT8 = 1 # 8-bit as in the LLM.int8() paper
|
||||
NF4 = 2 # 4-bit as in the QLoRA paper
|
||||
|
||||
|
||||
def convert_block(
|
||||
block: nn.Module,
|
||||
block_index: int,
|
||||
config: PretrainedConfig,
|
||||
tensor_parallel_devices: Sequence[torch.device],
|
||||
output_device: torch.device,
|
||||
quant_type: QuantType,
|
||||
freeze: bool = True,
|
||||
adapters: Optional[Sequence[str]] = None,
|
||||
**kwargs,
|
||||
) -> tp.TensorParallel:
|
||||
"""
|
||||
Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
|
||||
|
||||
:note: some optimizations will modify the input block in-place!
|
||||
:param block: a single transformer block, either pre-trained or newly initialized
|
||||
:param config: HF transformers config for the full model
|
||||
:param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
|
||||
:note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
|
||||
:param output_device: if tensor parallelism is used, gather the block's outputs on this device
|
||||
:param quant_type: quantization type
|
||||
:param freeze: if True (default), make all module parameters non-trainable
|
||||
:return: a module that acts like the original block, but runs with all specified optimizations
|
||||
|
||||
"""
|
||||
if freeze:
|
||||
block.requires_grad_(False)
|
||||
|
||||
block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
|
||||
|
||||
if quant_type != QuantType.NONE:
|
||||
block = quantize_module(block, quant_type=quant_type)
|
||||
|
||||
for shard, device in zip(block.module_shards, block.devices):
|
||||
shard.to(device)
|
||||
|
||||
if adapters:
|
||||
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
|
||||
|
||||
create_lora_adapter(block, quant_type=quant_type)
|
||||
for adapter_name in adapters:
|
||||
adapter_config, adapter_state_dict = load_peft(
|
||||
adapter_name,
|
||||
block_idx=block_index,
|
||||
**kwargs,
|
||||
)
|
||||
add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
|
||||
|
||||
return block
|
||||
|
||||
|
||||
def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
|
||||
# Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
|
||||
import bitsandbytes as bnb
|
||||
|
||||
for n, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
quantize_module(module, quant_type=quant_type)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
|
||||
assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
|
||||
if quant_type == QuantType.INT8:
|
||||
model._modules[n] = bnb.nn.Linear8bitLt(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0, # Default from the LLM.int8() paper
|
||||
)
|
||||
model._modules[n].weight = bnb.nn.Int8Params(
|
||||
module.weight.data, requires_grad=False, has_fp16_weights=False
|
||||
).to(module.weight.dtype)
|
||||
elif quant_type == QuantType.NF4:
|
||||
compress_statistics = True
|
||||
model._modules[n] = bnb.nn.LinearNF4(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
compress_statistics=compress_statistics,
|
||||
)
|
||||
model._modules[n].weight = bnb.nn.Params4bit(
|
||||
module.weight.data,
|
||||
requires_grad=False,
|
||||
quant_type="nf4",
|
||||
blocksize=64,
|
||||
compress_statistics=compress_statistics,
|
||||
).to(module.weight.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quant_type='{quant_type}'")
|
||||
model._modules[n].bias = module.bias
|
||||
return model
|
||||
|
||||
|
||||
def make_tensor_parallel(
|
||||
block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
|
||||
) -> nn.Module:
|
||||
if model_config.model_type == "bloom":
|
||||
tp_config = get_bloom_config(model_config, devices)
|
||||
del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
|
||||
else:
|
||||
if len(devices) > 1:
|
||||
logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
|
||||
tp_config = None
|
||||
tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
|
||||
total_heads = 0
|
||||
for tp_shard in tp_block.module_shards:
|
||||
for submodule in tp_shard.modules():
|
||||
if isinstance(submodule, model_config.attn_class):
|
||||
total_heads += submodule.num_heads
|
||||
assert total_heads == model_config.num_attention_heads
|
||||
return tp_block
|
||||
|
||||
|
||||
def check_device_balance(devices: Sequence[torch.device]):
|
||||
if not all(device.type == "cuda" for device in devices):
|
||||
logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
|
||||
return
|
||||
unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
|
||||
if len(unique_device_capabilities) > 1:
|
||||
logger.warning(
|
||||
f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
|
||||
f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
|
||||
)
|
||||
|
||||
memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
|
||||
used_memory = min(memory_per_device) * len(memory_per_device)
|
||||
wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
|
||||
if wasted_memory_rate > 0.05:
|
||||
logger.warning(
|
||||
f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
|
||||
f"Consider running high-memory GPUs in a separate server."
|
||||
)
|
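A hypothetical call site for convert_block with 8-bit quantization on a single GPU. Here `block` and `model_config` are placeholders for a transformer block and its HF config loaded elsewhere (see petals.server.from_pretrained for the real loading path):

import torch

from petals.utils.convert_block import QuantType, check_device_balance, convert_block

devices = [torch.device("cuda:0")]
check_device_balance(devices)  # warns about mismatched GPUs; trivially passes for a single device
block = convert_block(
    block,                      # placeholder: a single transformer block, loaded on CPU
    block_index=0,
    config=model_config,        # placeholder: HF config of the full model
    tensor_parallel_devices=devices,
    output_device=devices[0],
    quant_type=QuantType.INT8,
)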
@ -0,0 +1,7 @@
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
|
||||
def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
|
||||
loading_from_repo = model_name is not None and not os.path.isdir(model_name)
|
||||
return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
|
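A few illustrative checks that follow directly from the helper above:

assert always_needs_auth("meta-llama/Llama-2-70b-hf")  # gated Llama 2 repos always need a token
assert not always_needs_auth("bigscience/bloom-560m")  # public repos do not
assert not always_needs_auth(None)                     # no model name -> not loading from a repo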
@ -1,334 +0,0 @@
|
||||
"""
|
||||
A patch to bitsandbytes 0.34.0 that introduces an option to run the backward pass in the default (fast) matrix layout.
|
||||
Authors: modification by @borzunov, original code by @timdettmers. Please disregard commit authors in this file.
|
||||
|
||||
Core idea: layouts apply the same permutation to every tile in the matrix. We can treat this as (batched) gather ops.
|
||||
Reshape input tensor so that ij-th gather operation op will apply to ij-th elements in each tile.
|
||||
Prototype: https://colab.research.google.com/drive/1EJ0MKifajXSSVq7O2_QGwtb0l6gRAGrh?usp=sharing
|
||||
Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#L2130-L2136
|
||||
Exact match tests: see $REPO/tests/test_linear8bitlt.py
|
||||
"""
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import bitsandbytes.functional as F
|
||||
import torch
|
||||
from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatMul8bitLt, MatmulLtState, prod
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
|
||||
|
||||
def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
|
||||
"""
|
||||
Compute a permutation of indices that invert the specified (tiled) matrix transformation
|
||||
|
||||
:param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
|
||||
:param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
|
||||
:note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
|
||||
:example: transform_tile function for the turing layout (bitsandbytes.functional as F)
|
||||
:returns: indices
|
||||
"""
|
||||
d1, d2 = tile_size
|
||||
assert 0 < d1 * d2 < 2**64
|
||||
tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
|
||||
# encode each position in tile as a tuple of <= 8 unique bytes
|
||||
permuted_tile_indices = torch.zeros_like(tile_indices)
|
||||
for i in range(8):
|
||||
# select i-th byte, apply transformation and trace where each index ended up
|
||||
ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
|
||||
sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
|
||||
assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
|
||||
permuted_tile_i = transform_tile(sample_tile_i)
|
||||
ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
|
||||
permuted_tile_indices += ith_permuted_indices * (256**i)
|
||||
if d1 * d2 < 256**i:
|
||||
break # if all indices fit in i bytes, stop early
|
||||
return permuted_tile_indices
|
||||
|
||||
|
||||
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
|
||||
"""
|
||||
Undo a tiled permutation such as turing or ampere layout
|
||||
|
||||
:param permuted_tensor: torch tensor in a permuted layout
|
||||
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
|
||||
:return: contiguous row-major tensor
|
||||
"""
|
||||
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
|
||||
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
|
||||
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
|
||||
outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda
|
||||
outputs[tile_indices.flatten()] = tensor
|
||||
outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
|
||||
outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
|
||||
return outputs.reshape(rows, cols).contiguous()
|
||||
|
||||
|
||||
# the rest of this file is just a patch to bitsandbytes that modifies Linear8bitLt and dependencies
|
||||
|
||||
|
||||
class CustomLinear8bitLt(Linear8bitLt):
|
||||
def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
|
||||
assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
|
||||
super().__init__(*args, **kwargs)
|
||||
old_state, self.state = self.state, CustomMatmulLtState()
|
||||
self.state.threshold = old_state.threshold
|
||||
self.state.has_fp16_weights = old_state.has_fp16_weights
|
||||
self.state.memory_efficient_backward = old_state.memory_efficient_backward
|
||||
if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
self.state.is_training = self.training
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state)
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=True)
|
||||
class CustomMatmulLtState(MatmulLtState):
|
||||
tile_indices: Optional[torch.Tensor] = None
|
||||
force_no_igemmlt: bool = False
|
||||
|
||||
def get_tile_size(self):
|
||||
assert self.formatB in (
|
||||
"col_turing",
|
||||
"col_ampere",
|
||||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
||||
|
||||
|
||||
def custom_matmul8bitlt(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out: torch.Tensor = None,
|
||||
state: CustomMatmulLtState = None,
|
||||
threshold=0.0,
|
||||
bias=None,
|
||||
):
|
||||
state = state or MatmulLtState()
|
||||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return CustomMatMul8bitLt.apply(A, B, out, bias, state)
|
||||
|
||||
|
||||
class CustomMatMul8bitLt(MatMul8bitLt):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState):
|
||||
using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
ctx.bias = bias
|
||||
if A.shape[-1] == B.shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
# 5. Save state
|
||||
formatB = state.formatB
|
||||
input_shape = A.shape
|
||||
if state.outlier_pool is None:
|
||||
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||
|
||||
# Cast A to fp16
|
||||
if A.dtype != torch.float16:
|
||||
logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
|
||||
|
||||
# 1. Quantize A
|
||||
if len(A.shape) == 3:
|
||||
A = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
|
||||
|
||||
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||
if state.has_fp16_weights:
|
||||
idx = torch.unique(coo_tensorA.colidx).long()
|
||||
CA[:, idx] = 0
|
||||
CAt[:, idx] = 0
|
||||
subA = A[:, idx]
|
||||
state.subB = B[:, idx].t().contiguous()
|
||||
state.idx = idx
|
||||
else:
|
||||
if state.CxB is None and using_igemmlt:
|
||||
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
subA = None
|
||||
|
||||
# 2. Quantize B
|
||||
if state.has_fp16_weights:
|
||||
has_grad = getattr(B, "grad", None) is not None
|
||||
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||
if is_transposed:
|
||||
B = B.contiguous()
|
||||
|
||||
if (state.is_training and not has_grad) or state.CxB is None:
|
||||
state.reset_grads()
|
||||
(
|
||||
CB,
|
||||
state.CBt,
|
||||
state.SCB,
|
||||
state.SCBt,
|
||||
coo_tensorB,
|
||||
) = F.double_quant(B.to(torch.float16))
|
||||
if using_igemmlt:
|
||||
state.CxB, state.SB = F.transform(CB, to_order=formatB)
|
||||
else:
|
||||
state.CB = CB
|
||||
else:
|
||||
has_grad = False
|
||||
|
||||
if coo_tensorA is not None and not state.has_fp16_weights:
|
||||
# extract outliers
|
||||
|
||||
outlier_idx = torch.unique(coo_tensorA.colidx)
|
||||
state.idx = outlier_idx
|
||||
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||
# # do not use pool for 2nd FFN layer
|
||||
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||
# else:
|
||||
# state.idx = outlier_idx
|
||||
if state.CxB is not None:
|
||||
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
||||
else:
|
||||
outliers = state.CB[:, state.idx.long()].clone()
|
||||
|
||||
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
|
||||
CA[:, state.idx.long()] = 0
|
||||
CAt[:, state.idx.long()] = 0
|
||||
subA = A[:, state.idx.long()]
|
||||
|
||||
shapeB = state.SB[0] if state.SB else B.shape
|
||||
|
||||
if len(input_shape) == 3:
|
||||
output_shape = (input_shape[0], input_shape[1], shapeB[0])
|
||||
else:
|
||||
output_shape = (input_shape[0], shapeB[0])
|
||||
|
||||
# 3. Matmul
|
||||
if using_igemmlt:
|
||||
C32A, SA = F.transform(CA, "col32")
|
||||
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||
if bias is None or bias.dtype == torch.float16:
|
||||
# we apply the fused bias here
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||
output = output.to(A.dtype)
|
||||
else: # apply bias separately
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||
output = output.to(A.dtype).add_(bias)
|
||||
|
||||
else:
|
||||
A_wo_outliers = A.clone()
|
||||
if state.idx is not None:
|
||||
A_wo_outliers[:, state.idx.long()] = 0
|
||||
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
|
||||
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
|
||||
if bias is not None:
|
||||
output = output.add_(bias)
|
||||
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
if coo_tensorA is not None and subA is not None:
|
||||
output += torch.matmul(subA, state.subB)
|
||||
|
||||
# 5. Save state
|
||||
ctx.state = state
|
||||
|
||||
ctx.formatB = formatB
|
||||
ctx.grad_shape = input_shape
|
||||
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
ctx.tensors = (CAt, subA)
|
||||
ctx.tensor_states = (SCAt, state.idx)
|
||||
else:
|
||||
ctx.tensors = [None, None]
|
||||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
|
||||
return clone_func(output.view(output_shape))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
|
||||
CAt, subA = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
state = ctx.state
|
||||
grad_A = grad_B = grad_bias = None
|
||||
|
||||
if req_gradBias:
|
||||
# compute grad_bias first before changing grad_output dtype
|
||||
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
|
||||
|
||||
# Cast grad_output to fp16
|
||||
if len(grad_output.shape) == 3:
|
||||
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||
if req_gradB:
|
||||
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
|
||||
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
|
||||
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
|
||||
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
|
||||
if state.threshold > 0.0 and subA is not None:
|
||||
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
|
||||
|
||||
if req_gradA:
|
||||
if state.CBt is not None:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
if state.CxBt is None:
|
||||
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
elif state.CB is not None:
|
||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
elif state.CxB is not None:
|
||||
|
||||
if state.tile_indices is None:
|
||||
order, tile_size = state.formatB, state.get_tile_size()
|
||||
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
|
||||
|
||||
CB = (
|
||||
undo_layout(state.CxB, state.tile_indices)
|
||||
.to(ctx.dtype_A)
|
||||
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
|
||||
)
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
else:
|
||||
raise Exception("State must contain either CBt or CB or CxB matrix for backward")
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
@ -0,0 +1,288 @@
|
||||
import contextlib
|
||||
import re
|
||||
import time
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from hivemind.utils.logging import get_logger
|
||||
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
||||
from peft.tuners import lora
|
||||
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file
|
||||
from transformers.utils import get_file_from_repo
|
||||
|
||||
from petals.server.block_utils import resolve_block_dtype
|
||||
from petals.utils.convert_block import QuantType
|
||||
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def check_peft_repository(repo_id: str) -> bool:
|
||||
fs = HfFileSystem()
|
||||
list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
|
||||
return len(list_of_files) > 0
|
||||
|
||||
|
||||
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
|
||||
tensors = dict()
|
||||
is_tensors_found = dict()
|
||||
common_layer_pattern_re = (
|
||||
".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
|
||||
)
|
||||
with safe_open(filepath, framework=framework, device=device) as f:
|
||||
for k in f.keys():
|
||||
if re.match(common_layer_pattern_re, k):
|
||||
is_tensors_found[block_idx] = True
|
||||
tensors[k] = f.get_tensor(k)
|
||||
if not is_tensors_found.get(block_idx, False):
|
||||
logger.warning(f"There is no peft weights for block {block_idx}")
|
||||
return tensors
|
||||
|
||||
|
||||
def get_adapter_from_repo(
|
||||
repo_id: str,
|
||||
block_idx: Optional[int] = None,
|
||||
device: Optional[int] = None,
|
||||
*,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
|
||||
if config_path is None:
|
||||
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
||||
config = PeftConfig.from_json_file(config_path)
|
||||
|
||||
weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
|
||||
if weight_path is None:
|
||||
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
||||
if block_idx is None:
|
||||
return config, load_file(weight_path)
|
||||
return config, load_specific_module(block_idx, weight_path, device=device)
|
||||
|
||||
|
||||
def load_peft(
|
||||
repo_id: str,
|
||||
block_idx: Optional[int] = None,
|
||||
device: Optional[int] = None,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
cache_dir: str,
|
||||
max_disk_space: Optional[int] = None,
|
||||
delay: float = 30,
|
||||
):
|
||||
# TODO: Check whether it is possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
|
||||
|
||||
if not check_peft_repository(repo_id):
|
||||
raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
|
||||
|
||||
try:
|
||||
with allow_cache_reads(cache_dir):
|
||||
return get_adapter_from_repo(
|
||||
repo_id,
|
||||
block_idx,
|
||||
device,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
|
||||
|
||||
while True:
|
||||
try:
|
||||
with allow_cache_writes(cache_dir):
|
||||
config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
|
||||
config_file_size = get_hf_file_metadata(config_url, token=token).size
|
||||
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
||||
weight_file_size = get_hf_file_metadata(weight_url, token=token).size
|
||||
|
||||
file_size = config_file_size + weight_file_size
|
||||
if file_size is not None:
|
||||
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
||||
else:
|
||||
logger.warning(f"Failed to fetch size from peft repo {repo_id}")
|
||||
|
||||
return get_adapter_from_repo(
|
||||
repo_id,
|
||||
block_idx,
|
||||
device,
|
||||
revision=revision,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
class AdapterContextMixin:
|
||||
"""A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
|
||||
|
||||
ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
|
||||
_context_active_adapter = ADAPTER_NOT_SET
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def using_adapter(active_adapter: Optional[str]):
|
||||
prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AdapterContextMixin._context_active_adapter = prev
|
||||
|
||||
@property
|
||||
def active_adapter(self):
|
||||
if self._context_active_adapter == self.ADAPTER_NOT_SET:
|
||||
logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
|
||||
return self._context_active_adapter
|
||||
|
||||
@active_adapter.setter
|
||||
def active_adapter(self, value: Optional[str]):
|
||||
assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"
|
||||
|
||||
|
||||
using_adapter = AdapterContextMixin.using_adapter
|
||||
|
||||
|
||||
class LoraLinear(lora.Linear, AdapterContextMixin):
|
||||
"""LoRA linear layer that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
|
||||
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
|
||||
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
|
||||
|
||||
|
||||
def create_lora_adapter(block, quant_type: QuantType):
|
||||
for _, module in block.named_modules():
|
||||
for child_name, child in module.named_children():
|
||||
lora_wrapped_child = None
|
||||
if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
|
||||
continue
|
||||
if quant_type == QuantType.INT8:
|
||||
kwargs = {
|
||||
"has_fp16_weights": False,
|
||||
"threshold": 6.0,
|
||||
"bias": hasattr(child, "bias") and child.bias is not None,
|
||||
}
|
||||
lora_wrapped_child = LoraLinear8bitLt(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
**kwargs,
|
||||
)
|
||||
elif quant_type == QuantType.NF4:
|
||||
kwargs = {
|
||||
"compress_statistics": True,
|
||||
"quant_type": "nf4",
|
||||
"blocksize": 64,
|
||||
"bias": hasattr(child, "bias") and child.bias is not None,
|
||||
}
|
||||
lora_wrapped_child = LoraLinear4bit(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
**kwargs,
|
||||
)
|
||||
lora_wrapped_child.compute_dtype = child.compute_dtype
|
||||
else:
|
||||
bias = hasattr(child, "bias") and child.bias is not None
|
||||
lora_wrapped_child = LoraLinear(
|
||||
AdapterContextMixin.ADAPTER_NOT_SET,
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=bias,
|
||||
)
|
||||
if lora_wrapped_child:
|
||||
lora_wrapped_child.weight = child.weight
|
||||
lora_wrapped_child.bias = child.bias
|
||||
for p in lora_wrapped_child.parameters():
|
||||
p.requires_grad = False
|
||||
setattr(module, child_name, lora_wrapped_child)
|
||||
|
||||
|
||||
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
|
||||
assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
|
||||
if peft_config["lora_dropout"] > 0:
|
||||
logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
|
||||
|
||||
for _, module in block.named_modules():
|
||||
for child_name, child in module.named_children():
|
||||
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
|
||||
continue
|
||||
|
||||
if child_name in peft_config["target_modules"] or (
|
||||
isinstance(peft_config["target_modules"], str)
|
||||
and re.fullmatch(peft_config["target_modules"], child_name)
|
||||
):
|
||||
is_lora_a_loaded = False
|
||||
is_lora_b_loaded = False
|
||||
for peft_key in peft_state_dict:
|
||||
if child_name not in peft_key:
|
||||
continue
|
||||
|
||||
if adapter_name not in child.lora_A:
|
||||
child.update_layer(
|
||||
adapter_name,
|
||||
peft_config["r"],
|
||||
peft_config["lora_alpha"],
|
||||
lora_dropout=peft_config["lora_dropout"],
|
||||
init_lora_weights=peft_config["init_lora_weights"],
|
||||
)
|
||||
child.train(False)
|
||||
for p in child.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
if peft_key.endswith(".lora_A.weight"):
|
||||
child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
|
||||
is_lora_a_loaded = True
|
||||
elif peft_key.endswith(".lora_A.bias"):
|
||||
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
|
||||
elif peft_key.endswith(".lora_B.weight"):
|
||||
child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
|
||||
is_lora_b_loaded = True
|
||||
elif peft_key.endswith(".lora_B.bias"):
|
||||
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
|
||||
|
||||
if is_lora_a_loaded and is_lora_b_loaded:
|
||||
logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
|
||||
elif is_lora_a_loaded or is_lora_b_loaded:
|
||||
raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
|
||||
logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
|
||||
|
||||
|
||||
def estimate_adapter_memory_per_block(
|
||||
block_config: transformers.PretrainedConfig,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
adapters: Sequence[str],
|
||||
**load_peft_kwargs,
|
||||
) -> int:
|
||||
"""Get the number of extra bytes used to store a set of adapters per given block"""
|
||||
with init_empty_weights(include_buffers=True):
|
||||
block = block_config.block_class(block_config)
|
||||
base_block_parameters = sum(p.numel() for p in block.parameters())
|
||||
create_lora_adapter(block, quant_type=QuantType.NONE)
|
||||
|
||||
for adapter in adapters:
|
||||
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
|
||||
assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
|
||||
add_adapter_to_block(
|
||||
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
|
||||
)
|
||||
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
|
||||
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
|
||||
return adapter_parameters * bytes_per_parameter
|
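A hypothetical end-to-end flow for attaching one LoRA adapter to a converted block and running it with that adapter active. The repo id, block index, and the `block`/`hidden_states` objects are placeholders, not values from this diff:

from petals.utils.convert_block import QuantType
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft, using_adapter

create_lora_adapter(block, quant_type=QuantType.NONE)  # wrap the block's linear layers with LoRA stubs
peft_config, peft_state_dict = load_peft("user/my-lora-adapter", block_idx=3, cache_dir="./cache")
add_adapter_to_block(block, 3, "user/my-lora-adapter", peft_config, peft_state_dict)

with using_adapter("user/my-lora-adapter"):
    outputs = block(hidden_states)  # LoRA layers pick up the active adapter from the context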
@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import hivemind
|
||||
from hivemind.proto import dht_pb2
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def ping(
|
||||
peer_id: hivemind.PeerID,
|
||||
_dht: hivemind.DHT,
|
||||
node: hivemind.dht.DHTNode,
|
||||
*,
|
||||
wait_timeout: float = 5,
|
||||
) -> float:
|
||||
try:
|
||||
ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
|
||||
start_time = time.perf_counter()
|
||||
await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
|
||||
return time.perf_counter() - start_time
|
||||
except Exception as e:
|
||||
if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays)
|
||||
return time.perf_counter() - start_time
|
||||
|
||||
logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
|
||||
return math.inf
|
||||
|
||||
|
||||
async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
|
||||
rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
|
||||
return dict(zip(peer_ids, rpc_infos))
|
||||
|
||||
|
||||
class PingAggregator:
|
||||
def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
|
||||
self.dht = dht
|
||||
self.ema_alpha = ema_alpha
|
||||
self.expiration = expiration
|
||||
self.ping_emas = hivemind.TimedStorage()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
|
||||
current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
|
||||
logger.debug(f"Current RTTs: {current_rtts}")
|
||||
|
||||
with self.lock:
|
||||
expiration = hivemind.get_dht_time() + self.expiration
|
||||
for peer_id, rtt in current_rtts.items():
|
||||
prev_rtt = self.ping_emas.get(peer_id)
|
||||
if prev_rtt is not None and prev_rtt.value != math.inf:
|
||||
rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
|
||||
self.ping_emas.store(peer_id, rtt, expiration)
|
||||
|
||||
def to_dict(self) -> Dict[hivemind.PeerID, float]:
|
||||
with self.lock, self.ping_emas.freeze():
|
||||
smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
|
||||
logger.debug(f"Smothed RTTs: {smoothed_rtts}")
|
||||
return smoothed_rtts
|
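A minimal usage sketch for PingAggregator. The peer id below is the test peer id reused from this diff, and the module path petals.utils.ping is an assumption:

import hivemind

from petals.utils.ping import PingAggregator

dht = hivemind.DHT(start=True)
aggregator = PingAggregator(dht, ema_alpha=0.2, expiration=300)
peer_id = hivemind.PeerID.from_base58("QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX")
aggregator.ping([peer_id])
print(aggregator.to_dict())  # {peer_id: smoothed RTT in seconds}; math.inf marks unreachable peers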
@ -0,0 +1,12 @@
|
||||
import random
|
||||
from typing import Collection, List, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def sample_up_to(population: Collection[T], k: int) -> List[T]:
|
||||
if not isinstance(population, list):
|
||||
population = list(population)
|
||||
if len(population) > k:
|
||||
population = random.sample(population, k)
|
||||
return population
|
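Two illustrative calls (behavior follows directly from the function above):

peers = {"peer_a", "peer_b", "peer_c"}
print(sample_up_to(peers, k=5))  # all three peers, since k exceeds the population size
print(sample_up_to(peers, k=2))  # two peers chosen uniformly at random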
@ -0,0 +1,44 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from hivemind.utils.logging import TextStyle, get_logger
|
||||
from packaging.version import parse
|
||||
|
||||
import petals
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def validate_version() -> None:
|
||||
logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
|
||||
try:
|
||||
r = requests.get("https://pypi.python.org/pypi/petals/json")
|
||||
r.raise_for_status()
|
||||
response = r.json()
|
||||
|
||||
versions = [parse(ver) for ver in response.get("releases")]
|
||||
latest = max(ver for ver in versions if not ver.is_prerelease)
|
||||
|
||||
if parse(petals.__version__) < latest:
|
||||
logger.info(
|
||||
f"A newer version {latest} is available. Please upgrade with: "
|
||||
f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)
|
||||
|
||||
|
||||
def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
|
||||
if model_name_or_path is None:
|
||||
return None
|
||||
|
||||
match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
|
||||
if match is None:
|
||||
return model_name_or_path
|
||||
|
||||
logger.info(
|
||||
f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
|
||||
)
|
||||
return match.group(1)
|
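The behavior of get_compatible_model_repo follows directly from the regex above:

assert get_compatible_model_repo("bigscience/bloom-petals") == "bigscience/bloom"
assert get_compatible_model_repo("bigscience/bloom-560m") == "bigscience/bloom-560m"  # unchanged
assert get_compatible_model_repo(None) is None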
@ -1,25 +0,0 @@
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
from huggingface_hub import delete_repo, list_models
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
|
||||
parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
|
||||
parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
|
||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
||||
parser.add_argument("--dry_run", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
for model in list_models(author=args.author, full=True):
|
||||
last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
if model.modelId.endswith("-main") or "/test-" not in model.modelId:
|
||||
continue # remove only test models
|
||||
|
||||
if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
|
||||
if args.dry_run:
|
||||
print(f"{model.modelId} can be deleted")
|
||||
else:
|
||||
delete_repo(repo_id=model.modelId, token=args.use_auth_token)
|
Binary file not shown.
@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from petals.server.block_utils import resolve_block_dtype
|
||||
from petals.server.from_pretrained import load_pretrained_block
|
||||
from petals.utils.auto_config import AutoDistributedConfig
|
||||
from test_utils import MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
|
||||
def test_block_dtype(torch_dtype):
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||
block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
|
||||
expected_dtype = resolve_block_dtype(config, torch_dtype)
|
||||
assert all(param.dtype == expected_dtype for param in block.parameters())
|
@ -1,108 +0,0 @@
|
||||
import bitsandbytes as bnb
|
||||
import pytest
|
||||
import torch
|
||||
from bitsandbytes import functional as F
|
||||
|
||||
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt, get_inverse_transform_indices, undo_layout
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
|
||||
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
|
||||
)
|
||||
def test_layout_exact_match():
|
||||
x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
|
||||
for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
|
||||
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
|
||||
tile_indices = get_inverse_transform_indices(transform, tile_size)
|
||||
cxb = transform(x)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
restored_x = undo_layout(cxb, tile_indices)
|
||||
torch.cuda.synchronize()
|
||||
assert restored_x.is_contiguous()
|
||||
assert torch.all(torch.eq(restored_x, x))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
|
||||
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
|
||||
)
|
||||
def test_linear_exact_match():
|
||||
linear = torch.nn.Linear(1024, 3072)
|
||||
x = torch.randn(3, 1024, dtype=torch.half)
|
||||
linear8bitlt = bnb.nn.Linear8bitLt(
|
||||
linear.in_features,
|
||||
linear.out_features,
|
||||
linear.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
memory_efficient_backward=True,
|
||||
)
|
||||
linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
|
||||
linear.weight.dtype
|
||||
)
|
||||
linear8bitlt.bias = linear.bias
|
||||
linear8bitlt.cuda()
|
||||
|
||||
linear_custom = CustomLinear8bitLt(
|
||||
linear.in_features,
|
||||
linear.out_features,
|
||||
linear.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
linear_custom.weight = bnb.nn.Int8Params(
|
||||
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
||||
).to(linear.weight.dtype)
|
||||
linear_custom.bias = linear.bias
|
||||
linear_custom.cuda()
|
||||
|
||||
x_ref = x.clone().cuda().requires_grad_(True)
|
||||
x_ours = x.clone().cuda().requires_grad_(True)
|
||||
fx_ref = linear8bitlt(x_ref).float()
|
||||
grad_proj = torch.randn_like(fx_ref)
|
||||
(fx_ref * grad_proj).mean().backward()
|
||||
|
||||
fx_ours = linear_custom(x_ours).float()
|
||||
(fx_ours * grad_proj).mean().backward()
|
||||
assert torch.equal(fx_ref, fx_ours)
|
||||
assert torch.allclose(x_ref.grad, x_ours.grad)
|
||||
assert not linear_custom.state.has_fp16_weights
|
||||
assert linear_custom.state.CB is None
|
||||
assert linear_custom.state.CxB is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
def test_linear_no_igemmlt():
|
||||
linear = torch.nn.Linear(1024, 3072)
|
||||
x = torch.randn(3, 1024, dtype=torch.half)
|
||||
linear_custom = CustomLinear8bitLt(
|
||||
linear.in_features,
|
||||
linear.out_features,
|
||||
linear.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
linear_custom.state.force_no_igemmlt = True
|
||||
|
||||
linear_custom.weight = bnb.nn.Int8Params(
|
||||
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
||||
).to(linear.weight.dtype)
|
||||
linear_custom.bias = linear.bias
|
||||
linear_custom.cuda()
|
||||
linear.half().cuda()
|
||||
|
||||
x_ref = x.clone().cuda().requires_grad_(True)
|
||||
x_ours = x.clone().cuda().requires_grad_(True)
|
||||
fx_ref = linear(x_ref).float()
|
||||
grad_proj = torch.randn_like(fx_ref)
|
||||
(fx_ref * grad_proj).mean().backward()
|
||||
|
||||
fx_ours = linear_custom(x_ours).float()
|
||||
(fx_ours * grad_proj).mean().backward()
|
||||
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
|
||||
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
|
||||
assert not linear_custom.state.has_fp16_weights
|
||||
assert linear_custom.state.CB is not None
|
||||
assert linear_custom.state.CxB is None
|
@ -0,0 +1,66 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from petals.utils.peft import check_peft_repository, load_peft
|
||||
|
||||
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
|
||||
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
|
||||
TMP_CACHE_DIR = "tmp_cache/"
|
||||
|
||||
|
||||
def clear_dir(path_to_dir):
|
||||
shutil.rmtree(path_to_dir)
|
||||
os.mkdir(path_to_dir)
|
||||
|
||||
|
||||
def dir_empty(path_to_dir):
|
||||
files = os.listdir(path_to_dir)
|
||||
return len(files) == 0
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_check_peft():
|
||||
assert not check_peft_repository(UNSAFE_PEFT_REPO), "UNSAFE_PEFT_REPO was reported as safe to load."
|
||||
assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_noncached(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
with pytest.raises(Exception):
|
||||
load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_cached(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_layer_exists(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
|
||||
load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_load_layer_nonexists(tmpdir):
|
||||
clear_dir(tmpdir)
|
||||
|
||||
load_peft(
|
||||
SAFE_PEFT_REPO,
|
||||
block_idx=1337,
|
||||
cache_dir=tmpdir,
|
||||
)
|
@ -0,0 +1,39 @@
|
||||
import time
|
||||
|
||||
import hivemind
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from petals import AutoDistributedConfig, RemoteSequential
|
||||
from petals.server.handler import CACHE_TOKENS_AVAILABLE
|
||||
from test_utils import *
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
|
||||
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||
config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id
|
||||
|
||||
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
||||
blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
|
||||
blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)
|
||||
|
||||
info_before = blocks1.sequence_manager.rpc_info
|
||||
|
||||
with blocks1.inference_session(max_length=max_length) as sess:
|
||||
sess.step(torch.randn(1, 1, config.hidden_size))
|
||||
blocks1.sequence_manager.state.rpc_info = None # invalidate cache
|
||||
info_inside = blocks1.sequence_manager.rpc_info
|
||||
|
||||
with blocks2.inference_session(max_length=max_length2) as sess2:
|
||||
sess2.step(torch.randn(1, 1, config.hidden_size))
|
||||
blocks2.sequence_manager.state.rpc_info = None # invalidate cache
|
||||
info_inside2 = blocks2.sequence_manager.rpc_info
|
||||
|
||||
time.sleep(0.1)
|
||||
blocks1.sequence_manager.state.rpc_info = None # invalidate cache
|
||||
info_after = blocks1.sequence_manager.rpc_info
|
||||
|
||||
assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
|
||||
assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
|
||||
assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
|
@ -0,0 +1,49 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from tensor_parallel import TensorParallel
|
||||
from tensor_parallel.slicing_configs import get_bloom_config
|
||||
|
||||
from petals.server.from_pretrained import load_pretrained_block
|
||||
from test_utils import MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize("custom_config", [True, False])
|
||||
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
|
||||
def test_tp_block(devices, custom_config):
|
||||
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
|
||||
if model_config.model_type != "bloom":
|
||||
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
|
||||
|
||||
block_index = random.randint(0, 10)
|
||||
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
|
||||
|
||||
tp_config = None
|
||||
if custom_config:
|
||||
tp_config = get_bloom_config(model_config, devices)
|
||||
|
||||
batch_size = 2
|
||||
prefix_length = 5
|
||||
|
||||
test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
|
||||
test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
|
||||
test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
|
||||
test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
|
||||
grad_proj = torch.rand_like(test_inputs1)
|
||||
|
||||
y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
|
||||
y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
|
||||
y_ref.backward(grad_proj)
|
||||
|
||||
block_tp = TensorParallel(block, devices, config=tp_config)
|
||||
y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
|
||||
y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
|
||||
y_ours.backward(grad_proj)
|
||||
|
||||
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
|
||||
assert torch.allclose(y_ours, y_ref, atol=1e-5)
|
||||
assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
|
||||
assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
|