pull/554/head
Denis Mazur 4 months ago
parent 7ff5952499
commit 72e1f00599

@@ -25,6 +25,9 @@ class RemotePastKeyValues(Cache):
        self.seen_tokens = 0
        self.hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.seen_tokens
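For context, the point of subclassing Cache is that transformers' generation utilities ask the cache how many tokens it already covers via get_seq_length(). The snippet below is only an illustration of that contract using the stock DynamicCache from transformers.cache_utils (shapes are arbitrary); RemotePastKeyValues answers the same question from its seen_tokens counter instead of from locally stored key/value tensors.

import torch
from transformers.cache_utils import DynamicCache

# Illustration of the Cache contract only; tensor shapes are made up.
cache = DynamicCache()
key = value = torch.zeros(1, 2, 5, 4)  # (batch, num_heads, seq_len, head_dim)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length())  # 5 -- how many tokens the cache already covers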
@@ -130,6 +133,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
            past_key_values = RemotePastKeyValues()
            past_key_values.update_seen(session.position)
            kwargs["past_key_values"] = past_key_values

            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
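Seeding the cache with session.position is what lets repeated generate() calls inside one inference session pick up where the previous call stopped. A rough usage sketch following the pattern from the Petals README; the model name, max_length, and prompts are placeholders, and you need a swarm that actually serves the model.

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

MODEL_NAME = "bigscience/bloom-560m"  # placeholder; use a model your swarm serves
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME)

with model.inference_session(max_length=128) as sess:
    # First call: the prompt goes through the remote blocks; sess.position advances past it.
    prompt = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
    outputs = model.generate(prompt, max_new_tokens=8, session=sess)

    # Second call in the same session: generate() seeds RemotePastKeyValues with
    # sess.position, so only the newly appended tokens are sent to the servers.
    follow_up = tokenizer(" The capital of Germany is", return_tensors="pt")["input_ids"]
    outputs = model.generate(follow_up, max_new_tokens=8, session=sess)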

@@ -6,6 +6,7 @@ See commit history for authorship.
from typing import Optional, Tuple
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
@@ -26,7 +27,13 @@ class WrappedBloomBlock(BloomBlock):
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_length,
        )
        attention_mask = attention_mask.bool()
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
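The helper imported above replaces the removed BloomModel._prepare_attn_mask call, which newer transformers releases dropped. A small self-contained check of what it produces, with arbitrary toy dimensions; the trailing .bool() matters because BLOOM's attention expects a boolean mask where True marks positions to ignore, while the helper returns an additive float mask.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_length, hidden_size = 2, 3, 5, 8  # arbitrary toy sizes
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length + past_length)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,  # only its dtype/device are used here
    past_key_values_length=past_length,
)
print(mask_4d.shape)         # torch.Size([2, 1, 3, 8]) -- (batch, 1, query_len, key_len)
print(mask_4d.bool()[0, 0])  # True = masked: each new token may attend to the 5 cached
                             # tokens and to itself/predecessors, but not to future positions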

@@ -6,6 +6,7 @@ import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
from transformers.cache_utils import Cache
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
@@ -92,12 +93,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        if past_key_values is None:
            past_key_values = RemotePastKeyValues()
        past_key_values.update_seen(hidden_states.size(1))

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
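Threading the caller's past_key_values through (instead of returning a fresh RemotePastKeyValues() as before) is what keeps the seen-token count cumulative across decoding steps. A toy illustration of the difference, assuming update_seen() simply adds to seen_tokens, as its call sites in this diff suggest:

class _ToyCounter:
    """Stand-in for the seen-token bookkeeping; not the real RemotePastKeyValues."""
    def __init__(self):
        self.seen_tokens = 0
    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen

# Old behaviour (sketch): a fresh object per forward pass forgets the running total.
step1, step2 = _ToyCounter(), _ToyCounter()
step1.update_seen(5)   # 5-token prompt
step2.update_seen(1)   # next decoding step
assert step2.seen_tokens == 1   # generation would think only 1 token is cached

# New behaviour (sketch): one object accumulates, matching what the servers actually hold.
cache = _ToyCounter()
cache.update_seen(5)
cache.update_seen(1)
assert cache.seen_tokens == 6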
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
    _supports_cache_class = True

    config_class = DistributedBloomConfig
@@ -118,6 +124,61 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ) -> dict:
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
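The slicing above mirrors the stock transformers prepare_inputs_for_generation: when the cache already covers past_length tokens, only the unseen suffix of input_ids (and the matching tail of position_ids) is fed to the model. BLOOM itself uses ALiBi rather than position embeddings, so the position_ids are mostly carried along for interface compatibility. A small worked example with made-up numbers:

import torch

past_length = 5                                     # tokens already covered by the cache
input_ids = torch.arange(8).unsqueeze(0)            # full sequence re-fed by generate()
attention_mask = torch.ones(1, 8, dtype=torch.long)

# past_length < input_ids.shape[1], so only the unseen suffix is kept:
input_ids = input_ids[:, past_length:]              # tensor([[5, 6, 7]])

# position_ids are rebuilt from the attention mask, then cut down to the new tokens:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids[:, -input_ids.shape[1]:]  # tensor([[5, 6, 7]])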
    def _temporary_reorder_cache(self, past_key_values, beam_idx):
        return past_key_values

    def get_output_embeddings(self):
        return self.lm_head
