Don't cast logits to float32 on GPU

Aleksandr Borzunov 2022-11-29 08:58:58 +00:00
parent 6e7565e41e
commit 86d08bf515


@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
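
Note on the change: the removed .float() upcast the logits (a vocab-sized tensor) to float32 right after the matmul; dropping it keeps the output in the fp16/bf16 dtype of word_embeddings, which roughly halves the memory taken by the logits on GPU. The snippet below is a minimal standalone sketch of the before/after dtypes, not the project's actual LMHead code; the shapes and tensor values are hypothetical, and only the names word_embeddings, hidden_states, and F.linear come from the diff.

# Minimal sketch (hypothetical shapes) of the dtype effect of dropping .float()
import torch
import torch.nn.functional as F

vocab_size, hidden_size = 64, 16  # hypothetical small sizes
word_embeddings = torch.randn(vocab_size, hidden_size, dtype=torch.bfloat16)
hidden_states = torch.randn(2, 4, hidden_size)

# Switch dtype in case word_embeddings are fp16/bf16 (as in the diff context)
hidden_states = hidden_states.to(word_embeddings.dtype)

old_logits = F.linear(hidden_states, word_embeddings).float()  # before: upcast to fp32
new_logits = F.linear(hidden_states, word_embeddings)          # after: stays bf16/fp16

print(old_logits.dtype, new_logits.dtype)  # torch.float32 torch.bfloat16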