@@ -49,10 +49,14 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> BaseModelOutputWithPast:
-        # FIXME: Assert that the mask is None or triangle and position_ids are valid
-        # assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-        logger.warning(f"forward: {input_ids=} {self.layers.active_session=}")
+        position = self.layers.active_session.position if self.layers.active_session is not None else 0
+        n_input_tokens = input_ids.shape[1] if input_ids is not None else 0
+        # The causal mask will be added on the server-side
+        assert attention_mask is None or (attention_mask == 1).all(), "Custom attention masks are not supported"
+        if position_ids is not None:
+            expected = torch.arange(position, position + n_input_tokens, dtype=torch.long, device=position_ids.device)
+            assert (position_ids == expected).all(), "Custom position_ids are not supported"
 
         assert use_cache is None or use_cache, "use_cache=False is not supported"
         assert not output_attentions, "output_attentions=True is not supported"
         assert not output_hidden_states, "output_hidden_states=True is not supported"
@@ -71,11 +75,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if (
-            self.config.tuning_mode
-            and "ptune" in self.config.tuning_mode
-            and (self.layers.active_session is None or self.layers.active_session.position == 0)
-        ):
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
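
For context: the asserts added above only accept the positions that the active inference session would assign on its own, and the p-tuning prompts are now prepended exactly when `position == 0` (no active session, or a session that has not processed any tokens yet), which matches the old multi-line condition. Below is a minimal standalone sketch of the `position_ids` check, assuming only `torch`; the helper name `check_position_ids` is illustrative and not part of Petals.

```python
import torch

def check_position_ids(position_ids: torch.Tensor, position: int, n_input_tokens: int) -> None:
    # Same check as the assert introduced above: the caller may only pass the positions
    # the remote session would assign anyway, i.e. position .. position + n_input_tokens - 1.
    expected = torch.arange(position, position + n_input_tokens, dtype=torch.long, device=position_ids.device)
    assert (position_ids == expected).all(), "Custom position_ids are not supported"

# A session that has already processed 4 tokens and now receives 3 more (shape [batch, seq]):
check_position_ids(torch.tensor([[4, 5, 6]]), position=4, n_input_tokens=3)    # passes
# check_position_ids(torch.tensor([[0, 1, 2]]), position=4, n_input_tokens=3)  # would fail the assert
```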