Another one comment change

This commit is contained in:
Artem Chumachenko 2022-11-28 10:52:07 +04:00
parent ee1f56b492
commit 282f327425

View File

@ -152,7 +152,7 @@ class RemoteGenerationMixin:
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
# If samples have padded, so changes only them
# If some samples were padded, change only these samples
if seq_idx < inputs.size(1):
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[