diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index 29e3131..e2da719 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -152,7 +152,7 @@ class RemoteGenerationMixin: lm_logits = constraint(last_token_id, lm_logits, hypo_ids) last_token_id, hypo_ids = decoding_algorithm(lm_logits) - # If samples have padded, so changes only them + # If some samples were padded, change only these samples if seq_idx < inputs.size(1): pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id last_token_id = (~pad_token_mask) * inputs[