|
|
|
@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
|
|
|
|
|
from petals.client.from_pretrained import FromPretrainedMixin
|
|
|
|
|
from petals.client.lm_head import LMHead
|
|
|
|
|
from petals.client.ptune import PTuneMixin
|
|
|
|
|
from petals.client.remote_generation import RemoteGenerationMixin
|
|
|
|
|
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
|
|
|
|
from petals.client.remote_sequential import RemoteSequential
|
|
|
|
|
from petals.models.bloom.config import DistributedBloomConfig
|
|
|
|
|
|
|
|
|
@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
|
|
|
past_key_values: Optional[RemotePastKeyValues] = None,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
head_mask: Optional[torch.LongTensor] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
):
|
|
|
|
|
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
|
|
|
|
|
|
|
|
|
|
for k, v in kwargs.items():
|
|
|
|
|
if not (v is None or v is False):
|
|
|
|
|
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
|
|
|
|
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
|
elif input_ids is not None:
|
|
|
|
@ -59,21 +58,33 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
|
|
|
|
|
|
|
position = self.h.active_session.position if self.h.active_session is not None else 0
|
|
|
|
|
# The causal mask will be added on the server-side
|
|
|
|
|
assert attention_mask is None or (attention_mask == 1).all(), "Custom attention masks are not supported"
|
|
|
|
|
assert head_mask is None, "Custom head masks are not supported"
|
|
|
|
|
assert use_cache is None or use_cache, "use_cache=False is not supported"
|
|
|
|
|
assert not output_attentions, "output_attentions=True is not supported"
|
|
|
|
|
assert not output_hidden_states, "output_hidden_states=True is not supported"
|
|
|
|
|
assert return_dict is None or return_dict, "return_dict=True is not supported"
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
|
|
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and position == 0:
|
|
|
|
|
batch_size = inputs_embeds.shape[0]
|
|
|
|
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
|
|
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
|
|
|
else:
|
|
|
|
|
prompts = intermediate_prompts = None
|
|
|
|
|
|
|
|
|
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
|
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
|
|
|
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
|
|
|
|
else:
|
|
|
|
|
hidden_states = self.h(hidden_states)
|
|
|
|
|
hidden_states = self.h(
|
|
|
|
|
hidden_states,
|
|
|
|
|
prompts=intermediate_prompts,
|
|
|
|
|
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Remove prefix
|
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
|
|
@ -84,7 +95,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
|
|
|
|
hidden_states = hidden_states.view(output_shape)
|
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
|
|
last_hidden_state=hidden_states,
|
|
|
|
|
past_key_values=None,
|
|
|
|
|
past_key_values=RemotePastKeyValues(),
|
|
|
|
|
hidden_states=None,
|
|
|
|
|
attentions=None,
|
|
|
|
|
)
|
|
|
|
|