diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py
index 32c0b6f..296214d 100644
--- a/src/petals/models/falcon/model.py
+++ b/src/petals/models/falcon/model.py
@@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
         assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
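
For reviewers, here is a minimal sketch of what the new assertion accepts and rejects. The helper name `positions_are_consecutive` is hypothetical (introduced only for illustration, not part of the patch); the check itself mirrors the added assert: neighboring positions along the sequence dimension must differ by exactly 1.

```python
import torch

def positions_are_consecutive(position_ids: torch.LongTensor) -> bool:
    # Same expression as the patched assert: the difference between each
    # position and its predecessor must be 1 everywhere in the batch.
    return bool((position_ids[:, 1:] - position_ids[:, :-1] == 1).all())

assert positions_are_consecutive(torch.tensor([[0, 1, 2, 3]]))      # accepted
assert positions_are_consecutive(torch.tensor([[5, 6, 7]]))         # accepted (offset start is fine)
assert not positions_are_consecutive(torch.tensor([[0, 2, 3, 4]]))  # rejected: gap after 0
```

Note that for a single-token step (sequence length 1) the slice `position_ids[:, 1:]` is empty, and `.all()` on an empty tensor returns `True`, so one-token decoding passes the check, as does `position_ids=None`.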