pull/545/head
younesbelkada committed 7 months ago
parent: 5aebd3e8fc
commit: 401e791700

@@ -35,8 +35,8 @@ class _SkipTokensMixin:
         input_ids = input_ids[:, _skipped_tokens.get() :]
         _skipped_tokens.set(0)
         if "past_key_values" in kwargs:
-            if kwargs['past_key_values'][0][0].shape == torch.Size([0]):
-                kwargs['past_key_values'] = None
+            if kwargs["past_key_values"][0][0].shape == torch.Size([0]):
+                kwargs["past_key_values"] = None
         return super().prepare_inputs_for_generation(input_ids, **kwargs)
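
For context, this override trims prompt tokens that were already consumed in an earlier step and normalizes an "empty" cache placeholder (zero-sized tensors) back to None before deferring to the base implementation. Below is a minimal, standalone sketch of the same idea; it assumes _skipped_tokens is a contextvars.ContextVar[int] and that the mixin is combined with a regular Hugging Face causal LM class, which are assumptions for illustration rather than the repository's exact wiring.

import contextvars

import torch

# Assumption: a context variable tracking how many input tokens were already processed.
_skipped_tokens = contextvars.ContextVar("skipped_tokens", default=0)


class _SkipTokensMixin:
    # Intended to be mixed into a model class whose MRO provides
    # prepare_inputs_for_generation, e.g. class MyModel(_SkipTokensMixin, SomeCausalLM).
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        # Drop the prefix of tokens that was already handled in a previous step.
        input_ids = input_ids[:, _skipped_tokens.get() :]
        _skipped_tokens.set(0)

        # Zero-sized past tensors mean there is no usable cache; replace the
        # placeholder with None so the base class rebuilds inputs from scratch.
        if "past_key_values" in kwargs:
            if kwargs["past_key_values"][0][0].shape == torch.Size([0]):
                kwargs["past_key_values"] = None

        return super().prepare_inputs_for_generation(input_ids, **kwargs)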

@@ -27,7 +27,7 @@ class WrappedBloomBlock(BloomBlock):
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_length),
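
This hunk switches the block to the _prepare_4d_causal_attention_mask helper from transformers' modeling_attn_mask_utils. The hunk is cut off mid-call above, so the sketch below only illustrates how build_alibi_tensor and that helper are typically combined with toy shapes; the trailing inputs_embeds and past_key_values_length arguments reflect the helper's usual invocation, not necessarily the exact continuation of this diff.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

# Toy shapes; a nonzero past_length would model cached tokens from earlier generation steps.
batch_size, seq_length, past_length, num_heads = 2, 5, 0, 4
hidden_states = torch.randn(batch_size, seq_length, 64)
seq_length_with_past = seq_length + past_length

# Default mask that attends to every position (built when the caller passes no mask).
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)

# ALiBi biases, one slope per head: shape (batch_size * num_heads, 1, seq_length_with_past).
alibi = build_alibi_tensor(attention_mask, num_heads=num_heads, dtype=hidden_states.dtype)

# Expand the 2D padding mask into a 4D causal mask:
# shape (batch_size, 1, seq_length, seq_length_with_past).
attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_length,
)
print(alibi.shape, attention_mask.shape)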
