Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00.
Remove excess .float() cast
Commit: d8dac556a6 (parent: 6190a5909e)
@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)

         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
Loading…
Reference in New Issue
Block a user