From 86d08bf515b097ffce2051c4b58c6cf63a17e9b1 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Tue, 29 Nov 2022 08:58:58 +0000
Subject: [PATCH] Don't cast logits to float32 on GPU

---
 src/bloom/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/bloom/model.py b/src/bloom/model.py
index 574140b..566020d 100644
--- a/src/bloom/model.py
+++ b/src/bloom/model.py
@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
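
For context, a minimal sketch of what the patched branch of LMHead.forward computes, written as a free-standing function with hypothetical names (lm_head_forward and the example tensor sizes below are illustrative, not part of src/bloom/model.py): the projection reuses the word-embedding matrix as the LM head weight, and after this patch the logits stay in the embeddings' dtype (fp16/bf16) rather than being upcast to float32, so no extra fp32 copy of a [batch, seq_len, vocab_size] tensor is materialized on the GPU.

    import torch
    import torch.nn.functional as F

    def lm_head_forward(hidden_states: torch.Tensor, word_embeddings: torch.Tensor) -> torch.Tensor:
        # Illustrative standalone version of the patched branch: project hidden
        # states onto the (tied) word-embedding matrix and keep the result in
        # the embeddings' dtype instead of upcasting to float32.
        hidden_states = hidden_states.to(word_embeddings.dtype)
        return F.linear(hidden_states, word_embeddings)  # no trailing .float() anymore

    # Example with made-up sizes; the logits come out in bfloat16 instead of float32.
    emb = torch.randn(250_880, 1024, dtype=torch.bfloat16)  # [vocab_size, hidden_size]
    h = torch.randn(1, 8, 1024, dtype=torch.bfloat16)       # [batch, seq_len, hidden_size]
    print(lm_head_forward(h, emb).dtype)                     # torch.bfloat16

A caller that still needs float32 logits can cast the returned tensor itself.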