past_key_values to None if zero shape

pull/545/head
younesbelkada 7 months ago
parent fa254cff02
commit 741b5394cc

@ -34,6 +34,9 @@ class _SkipTokensMixin:
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
if "past_key_values" in kwargs:
if kwargs['past_key_values'][0][0].shape == torch.Size([0]):
kwargs['past_key_values'] = None
return super().prepare_inputs_for_generation(input_ids, **kwargs)

Loading…
Cancel
Save