|
|
|
# (extraction artifact — original diff hunk header) @@ -204,12 +204,12 @@ class BloomMLP(nn.Module):
|
|
|
|
|
class BloomBlock(nn.Module):
    """One BLOOM transformer layer: pre-LayerNorm self-attention followed by
    a pre-LayerNorm MLP (residual wiring presumably lives in `forward`, not
    visible here — confirm against the full file)."""

    def __init__(self, config, layer_number=None):
        """Build the layer's submodules from the model config.

        Args:
            config: model configuration; this constructor reads
                ``hidden_size``, ``n_head`` and ``layer_norm_epsilon``.
            layer_number: optional layer index forwarded to
                ``BloomAttention`` (its exact use is defined there —
                NOTE(review): confirm semantics in BloomAttention).
        """
        super().__init__()
        # NOTE(review): SOURCE contained both the old local `hidden_size`
        # and the new `self.hidden_size` assignments (an unmerged diff);
        # the post-change form is kept here.
        self.hidden_size = config.hidden_size

        # Normalization applied to the block input before self-attention.
        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
        self.n_head = config.n_head
        self.self_attention = BloomAttention(config, layer_number=layer_number)

        # Normalization applied between the attention output and the MLP.
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config)