Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
Don't cast logits to float32 on GPU
commit 86d08bf515
parent 6e7565e41e
@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
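A minimal runnable sketch of what the change does (not the repository's actual LMHead; the tensor shapes here are hypothetical, while the hidden_states / word_embeddings names follow the diff): dropping .float() keeps the logits in the embeddings' half-precision dtype instead of upcasting them to float32 on the GPU.

import torch
import torch.nn.functional as F

# Hypothetical sizes; the real model uses BLOOM-scale hidden and vocab dimensions.
hidden_states = torch.randn(2, 5, 64)                          # fp32 activations
word_embeddings = torch.randn(1000, 64, dtype=torch.bfloat16)  # half-precision tied embeddings

# Switch dtype in case word_embeddings are fp16/bf16 (context line from the diff)
hidden_states = hidden_states.to(word_embeddings.dtype)

# Before this commit: an extra float32 copy of a [batch, seq, vocab]-sized tensor.
old_logits = F.linear(hidden_states, word_embeddings).float()

# After this commit: logits keep the embeddings' dtype, saving memory and bandwidth.
new_logits = F.linear(hidden_states, word_embeddings)

print(old_logits.dtype)  # torch.float32
print(new_logits.dtype)  # torch.bfloat16

If float32 logits are still needed downstream (for example, for a numerically stable softmax or loss), the caller can upcast explicitly at that point.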