From e520c4781e264d0b63f54538415788fb2724b46d Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Wed, 30 Nov 2022 15:36:37 +0000 Subject: [PATCH] Keep prompts in float32, cast where necessary --- src/petals/client/remote_model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index b9923fa..eba613c 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -95,12 +95,18 @@ class DistributedBloomModel(BloomModel): self.prefix_tokens = torch.arange(self.pre_seq_len).long() with force_non_empty_weights(): - self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size) + if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16): + logger.info( + "Prompt embeddings and their optimizer statistics will be kept in float32 " + "to increase ptune quality" + ) + self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) if config.tuning_mode == "deep_ptune": self.intermediate_prompt_embeddings = nn.Embedding( self.pre_seq_len, - config.num_hidden_layers * config.hidden_size + config.num_hidden_layers * config.hidden_size, # ^-- TODO: should be num_hidden_layers - 1 + dtype=torch.float32, ) elif config.tuning_mode: raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") @@ -122,7 +128,9 @@ class DistributedBloomModel(BloomModel): intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) else: intermediate_prompts = DUMMY - return prompts, intermediate_prompts + + dtype = self.word_embeddings.weight.dtype + return prompts.to(dtype), intermediate_prompts.to(dtype) def forward( self,