@@ -4,6 +4,7 @@ import hivemind
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
 
@@ -92,12 +93,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
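Note on the hunk above: DistributedBloomModel.forward now threads a single RemotePastKeyValues object through the whole generation session (creating one only when none was passed) and advances its token counter via update_seen(), instead of returning a fresh instance on every call. The sketch below is a hypothetical illustration of the interface this relies on; the real RemotePastKeyValues is defined elsewhere in the codebase and may differ.

    # Hypothetical sketch (not from this patch): a Cache subclass that only
    # tracks how many tokens the remote servers have already processed.
    from typing import Optional

    from transformers.cache_utils import Cache


    class RemotePastKeyValuesSketch(Cache):
        def __init__(self) -> None:
            super().__init__()
            self.seen_tokens = 0  # no key/value tensors are stored locally

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            return self.seen_tokens

        def get_max_length(self) -> Optional[int]:
            return None  # the remote cache is not length-limited on the client side

        def update_seen(self, new_seen: int) -> None:
            self.seen_tokens += new_seen
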
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+    _supports_cache_class = True
 
     config_class = DistributedBloomConfig
 
@@ -118,6 +124,58 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
         # Initialize weights and apply final processing
         self.post_init()
 
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ) -> dict:
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _temporary_reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
+
     def get_output_embeddings(self):
         return self.lm_head
 
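Note on prepare_inputs_for_generation above: when past_key_values reports that past_length tokens were already processed, only the not-yet-seen suffix of input_ids is forwarded, and position_ids are rebuilt from the attention mask so batch generation with padding keeps correct positions. A rough, self-contained illustration of that trimming with made-up sizes:

    # Illustration only: mimics the branches above with a fake past_length of 5.
    import torch

    input_ids = torch.arange(6).unsqueeze(0)              # shape (1, 6): 5 cached tokens + 1 new one
    attention_mask = torch.ones(1, 6, dtype=torch.long)
    past_length = 5                                       # what the cache would report as seen_tokens

    # Same as the `elif past_length < input_ids.shape[1]` branch:
    input_ids = input_ids[:, past_length:]                # shape (1, 1): only the newest token is sent

    # position_ids rebuilt from the mask, then cut to the trimmed length:
    position_ids = attention_mask.long().cumsum(-1) - 1   # [[0, 1, 2, 3, 4, 5]]
    position_ids.masked_fill_(attention_mask == 0, 1)
    position_ids = position_ids[:, -input_ids.shape[1]:]  # [[5]]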