@ -34,6 +34,9 @@ class _SkipTokensMixin:
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
if "past_key_values" in kwargs:
if kwargs['past_key_values'][0][0].shape == torch.Size([0]):
kwargs['past_key_values'] = None
return super().prepare_inputs_for_generation(input_ids, **kwargs)