design interface & refactoring

pull/17/head
Dmitry Baranchuk 2 years ago
parent be83e6d0cb
commit e66ab6f1f2

@@ -1 +1 @@
from src.bloom.model import BloomBlock, BloomForCausalLM, BloomModel, DistributedBloomConfig
from src.bloom.model import BloomBlock, BloomForYou, BloomModel, DistributedBloomConfig

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url
from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
from src.bloom import BloomBlock, DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -23,7 +23,6 @@ logger = get_logger(__file__)
CLIENT_BRANCH = "client"
BLOCK_BRANCH_PREFIX = "block_"
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
cls = BloomForCausalLM
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False
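
The constants above parametrize how per-branch checkpoints are fetched from the Hub. As a rough, hedged sketch of how they might be wired together with the `hf_bucket_url` and `cached_path` imports above (the repository's actual `_load_state_dict` may differ; this body is an illustrative reconstruction, not code from the commit):

```python
# Illustrative sketch only: a plausible shape for `_load_state_dict`, built from
# the constants and imports above. The real implementation may differ.
import torch

def _load_state_dict(model_name: str, revision: str = CLIENT_BRANCH, use_auth_token=None):
    # Resolve the weights file for the given branch (e.g. "client" or "block_<i>")
    archive_url = hf_bucket_url(model_name, filename=WEIGHTS_NAME, revision=revision)
    archive_path = cached_path(
        archive_url,
        force_download=FORCE_DOWNLOAD,
        resume_download=RESUME_DOWNLOAD,
        local_files_only=LOCAL_FILES_ONLY,
        use_auth_token=use_auth_token,
        user_agent=USER_AGENT,
    )
    return torch.load(archive_path, map_location="cpu")
```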

@@ -4,8 +4,6 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
See commit history for authorship.
"""
from typing import Tuple
import torch
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
@@ -314,114 +312,17 @@ class BloomModel(BloomPreTrainedModel):
@add_start_docstrings(
"""
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
The Bloom interface for various applications, e.g., inference, classification...
""",
BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomModel):
class BloomForYou(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = super().forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Switch dtype in case word_embeddings are fp16
word_embeddings = self.word_embeddings.weight.t()
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
lm_logits = (hidden_states @ word_embeddings).float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = None
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
# Initialize weights and apply final processing
self.post_init()
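
With the language-modeling head factored out (`self.lm_head = None`), `BloomForYou` becomes a task-agnostic base, as the new docstring above suggests. As a hedged illustration of that extension point (not part of this commit; `BloomForSequenceClassification`, `num_labels`, and `score` are hypothetical names), a downstream head could be attached like this:

```python
import torch.nn as nn

class BloomForSequenceClassification(BloomForYou):
    """Hypothetical task head on top of the BloomForYou backbone (illustration only)."""

    def __init__(self, config, num_labels: int = 2):
        super().__init__(config)  # builds self.transformer; lm_head stays None
        self.score = nn.Linear(config.hidden_size, num_labels, bias=False)

    def forward(self, input_ids, **kwargs):
        hidden_states = self.transformer(input_ids, **kwargs)[0]  # [batch, seq_len, hidden_size]
        return self.score(hidden_states[:, -1, :])  # classify from the last token's state
```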

@@ -1,28 +1,33 @@
# this code is in active development, interfaces may change
import os
from typing import Optional, Union
from typing import Optional, Union, Tuple
import hivemind
from hivemind import DHT, get_logger, use_hivemind_log_handler
from src.bloom import BloomModel, BloomForCausalLM, DistributedBloomConfig
from src.bloom import BloomForYou, DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
import torch
from hivemind import use_hivemind_log_handler
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class DistributedBloomForCausalLM(BloomForCausalLM):
"""BloomForCausalLM, but all transformer layers are hosted by the swarm"""
class DistributedBloomForYou(BloomForYou):
"""BloomModel, but all transformer layers are hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
super().__init__(config)
assert len(self.h) == 0
assert len(self.transformer.h) == 0
config.n_layer = n_layer
self.h = RemoteSequential(config, dht, prefix)
self.transformer.h = RemoteSequential(config, dht, prefix)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
@@ -41,7 +46,84 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
model = cls(config, dht, prefix)
model.load_state_dict(_load_state_dict(
model.transformer.load_state_dict(_load_state_dict(
pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
), strict=True)
return model
class DistributedBloomForCausalLM(DistributedBloomForYou):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer.forward(
input_ids=input_ids, return_dict=return_dict, **kwargs
)
# Switch dtype in case word_embeddings are fp16
word_embeddings = self.transformer.word_embeddings.weight.t()
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
lm_logits = (hidden_states @ word_embeddings).float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
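
Taken together, the refactoring keeps the public entry point intact: `from_pretrained` accepts the DHT peers and layer prefix, and `forward` reuses the tied word embeddings as the output projection. A minimal usage sketch, mirroring the updated test below (`MODEL_NAME`, `INITIAL_PEERS`, and the `bloom6b3` prefix come from the test's environment, not from this commit):

```python
import os
import transformers
from src.client.remote_model import DistributedBloomForCausalLM

tokenizer = transformers.BloomTokenizerFast.from_pretrained(os.environ["MODEL_NAME"])
model = DistributedBloomForCausalLM.from_pretrained(
    os.environ["MODEL_NAME"],
    initial_peers=os.environ["INITIAL_PEERS"].split(),
    prefix="bloom6b3",
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
logits = model(inputs).logits  # lm_logits are computed via the tied word embeddings
```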

@@ -24,9 +24,9 @@ if not MODEL_NAME:
REF_NAME = os.environ.get("REF_NAME")
def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
assert len(model.transformer.h) == model.config.n_layer
test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
@@ -52,6 +52,9 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
dictionary = model.transformer.word_embeddings.weight.t()
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
recurrent_outputs = (recurrent_outputs @ dictionary).float()
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
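
Since `lm_head` no longer exists, the test now reproduces the tied-embedding readout by hand. The same three lines can be read as a small helper (a sketch for clarity, not code from the commit):

```python
import torch

def tied_lm_logits(hidden_states: torch.Tensor, word_embeddings: torch.nn.Embedding) -> torch.Tensor:
    dictionary = word_embeddings.weight.t()  # [hidden_size, vocab_size]
    # cast to the embedding dtype (e.g. fp16) before the matmul, then return fp32 logits
    return (hidden_states.to(dictionary.dtype) @ dictionary).float()
```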

@@ -1,60 +0,0 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os
import torch
import transformers
from hivemind import use_hivemind_log_handler, get_logger
from src.client.remote_model import DistributedBloomForCausalLM
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")
def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix="bloom6b3")
assert len(model.h) == model.config.n_layer
test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
parallel_outputs = model.forward(test_inputs).logits
assert torch.all(torch.isfinite(parallel_outputs))
logger.info("Forward outputs are finite")
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
embs = model.word_embeddings(test_inputs)
embs = model.word_embeddings_layernorm(embs.float())
recurrent_outputs = []
with model.h.inference_session() as sess:
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.ln_f(recurrent_outputs)
dictionary = model.word_embeddings.weight.t()
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
recurrent_outputs = (recurrent_outputs @ dictionary).float()
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")