Upgrade BLOOM

Branch: pull/464/head
Author: Aleksandr Borzunov, 10 months ago
Parent: a1a0d30b77
Commit: 062cf519b1

@@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
 from petals.client.from_pretrained import FromPretrainedMixin
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
-from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.bloom.config import DistributedBloomConfig
@@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
-        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -59,21 +58,33 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
+        position = self.h.active_session.position if self.h.active_session is not None else 0
+        # The causal mask will be added on the server-side
+        assert attention_mask is None or (attention_mask == 1).all(), "Custom attention masks are not supported"
+        assert head_mask is None, "Custom head masks are not supported"
+        assert use_cache is None or use_cache, "use_cache=False is not supported"
+        assert not output_attentions, "output_attentions=True is not supported"
+        assert not output_hidden_states, "output_hidden_states=True is not supported"
+        assert return_dict is None or return_dict, "return_dict=False is not supported"
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.h(hidden_states)
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
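For context on the new hypo_ids argument threaded into self.h(...) above: during generation the remote servers keep the attention caches, and hypo_ids tell them how to reorder those cached hypotheses when the client reshuffles the batch (e.g. between beam-search steps). The snippet below is only a toy, local illustration of that reindexing, not Petals internals; the tensors and shapes are made up.

import torch

# Toy stand-in for a per-hypothesis cache held on a server: one row per hypothesis.
cache = torch.arange(6.0).reshape(3, 2)   # 3 hypotheses, 2 cached values each
hypo_ids = torch.tensor([2, 0, 0])        # the client kept hypotheses 2, 0, 0 in its new batch order

reordered = cache[hypo_ids]               # gather rows so the cache matches the new order
print(reordered)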
@@ -84,7 +95,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=None,
+            past_key_values=RemotePastKeyValues(),
             hidden_states=None,
             attentions=None,
         )
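Seen from the caller, these BLOOM changes make DistributedBloomModel.forward() accept the standard Hugging Face keyword arguments while keeping the actual key/value cache on the servers; the client only passes around a lightweight RemotePastKeyValues handle. Below is a minimal usage sketch, not part of this commit: the checkpoint name is a placeholder, it assumes a Petals swarm is actually serving that model, and only public calls (AutoTokenizer.from_pretrained, DistributedBloomForCausalLM.from_pretrained, generate) are used.

from transformers import AutoTokenizer
from petals import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-560m"  # placeholder: any BLOOM checkpoint served by a reachable swarm

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Each decoding step goes through the client-side forward() shown above; the attention
# cache stays on the remote servers instead of being returned as past_key_values tensors.
outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))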

@@ -1,14 +1,13 @@
-from typing import List, Optional, Tuple, Union
+from typing import Optional
 import hivemind
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
 from petals.client.from_pretrained import FromPretrainedMixin
-from petals.client.inference_session import InferenceSession
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
 from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
@@ -49,19 +48,6 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> BaseModelOutputWithPast:
-        position = self.layers.active_session.position if self.layers.active_session is not None else 0
-        n_input_tokens = input_ids.shape[1] if input_ids is not None else 0
-        # The causal mask will be added on the server-side
-        assert attention_mask is None or (attention_mask == 1).all(), "Custom attention masks are not supported"
-        if position_ids is not None:
-            expected = torch.arange(position, position + n_input_tokens, dtype=torch.long, device=position_ids.device)
-            assert (position_ids == expected).all(), "Custom position_ids are not supported"
-        assert use_cache is None or use_cache, "use_cache=False is not supported"
-        assert not output_attentions, "output_attentions=True is not supported"
-        assert not output_hidden_states, "output_hidden_states=True is not supported"
-        assert return_dict is None or return_dict, "return_dict=False is not supported"
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -72,6 +58,17 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
+        position = self.layers.active_session.position if self.layers.active_session is not None else 0
+        # The causal mask will be added on the server-side
+        assert attention_mask is None or (attention_mask == 1).all(), "Custom attention masks are not supported"
+        if position_ids is not None:
+            expected = torch.arange(position, position + input_shape[1], dtype=torch.long, device=position_ids.device)
+            assert (position_ids == expected).all(), "Custom position_ids are not supported"
+        assert use_cache is None or use_cache, "use_cache=False is not supported"
+        assert not output_attentions, "output_attentions=True is not supported"
+        assert not output_hidden_states, "output_hidden_states=True is not supported"
+        assert return_dict is None or return_dict, "return_dict=False is not supported"
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
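The LLaMA hunks move the same assertions after input validation and derive the token count from input_shape[1] instead of a separate n_input_tokens variable. The position_ids check simply requires the consecutive range starting at the active session's current position; the standalone sketch below (not from the commit) reproduces that check in isolation.

import torch

def check_position_ids(position_ids: torch.LongTensor, position: int) -> None:
    # Mirrors the assertion above: only "default" consecutive positions, offset by the
    # inference session's current position, are accepted by the client-side model.
    n_input_tokens = position_ids.shape[1]
    expected = torch.arange(position, position + n_input_tokens, dtype=torch.long, device=position_ids.device)
    assert (position_ids == expected).all(), "Custom position_ids are not supported"

check_position_ids(torch.arange(5, 10).unsqueeze(0), position=5)  # ok: positions 5..9 for a session at offset 5
try:
    check_position_ids(torch.tensor([[0, 1, 3, 4]]), position=0)  # rejected: not consecutive from the offset
except AssertionError as err:
    print("rejected:", err)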
