@@ -40,19 +40,23 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        session: Optional[InferenceSession] = None,
        **kwargs,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BaseModelOutputWithPast:
        # FIXME: Assert that the mask is None or triangle
        # FIXME: Assert that the mask is None or triangle and position_ids are valid
        # assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
        logger.warning(f"forward: {input_ids.shape=} {session=} {kwargs=}")
        logger.warning(f"forward: {input_ids=} {self.layers.active_session=}")

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.warning(f"Extra keyword arguments are not yet supported (got {k} = {v})")
        assert use_cache is None or use_cache, "use_cache=False is not supported"
        assert not output_attentions, "output_attentions=True is not supported"
        assert not output_hidden_states, "output_hidden_states=True is not supported"
        assert return_dict is None or return_dict, "return_dict=True is not supported"

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,7 +74,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
        if (
            self.config.tuning_mode
            and "ptune" in self.config.tuning_mode
            and (session is None or session.position == 0)
            and (self.layers.active_session is None or self.layers.active_session.position == 0)
        ):
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
@@ -81,14 +85,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if session is not None:
            hidden_states = session.step(
                hidden_states,
                prompts=intermediate_prompts,
                hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
            )
        else:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
        hidden_states = self.layers(hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@@ -137,74 +134,6 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        session: Optional[InferenceSession] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            session=session,
        )

        hidden_states = outputs[0]
        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_output_embeddings(self):
        return self.lm_head