|
|
|
from contextlib import contextmanager
|
|
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
|
|
from transformers.models.bloom import (
|
|
|
|
BloomConfig,
|
|
|
|
BloomForCausalLM,
|
|
|
|
BloomForSequenceClassification,
|
|
|
|
BloomModel,
|
|
|
|
BloomPreTrainedModel,
|
|
|
|
)
|
|
|
|
|
|
|
|
from petals.bloom.modeling_utils import LMHead
|
|
|
|
from petals.client.remote_generation import RemoteGenerationMixin
|
|
|
|
from petals.client.remote_sequential import RemoteSequential
|
# Refactor RemoteSequenceManager (#309)
# This PR:
# 1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**
# The config is provided by the caller and never changed from inside `RemoteSequenceManager`. The state is a part of the `RemoteSequenceManager`'s state shared between the main manager and its slices. We fix some slicing bugs along the way.
# 2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**
# `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` is actually needed only under the hood of `RemoteSequenceManager`, so it can extract it by itself without exposing this low-level class to callers. If strictly necessary, a caller can provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided.
# 3. **Simplifies retry logic.**
# Previously, we could have "nested" retry loops: one in `._update()`, another in inference/forward/backward steps. The loop in `._update()` could introduce issues to concurrent inference/forward/backward calls, since it blocks the entire class if its delay period becomes too high. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, and any retries are triggered by the inference/forward/backward steps.
# 4. **Removes deprecated `RemoteTransformerBlock`.**
# `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long overdue.
# 5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**
# These functions duplicate the functionality of the `RemoteSequential` constructor.
# 6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**
# This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
# 1 year ago
|
|
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
|
|
|
from petals.constants import PUBLIC_INITIAL_PEERS
|
|
|
|
from petals.utils.misc import DUMMY
|
|
|
|
|
|
|
|
# Module-level logger created with hivemind's get_logger(); used below for info/debug messages.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
# Refactor RemoteSequenceManager (#309)
# This PR:
# 1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**
# The config is provided by the caller and never changed from inside `RemoteSequenceManager`. The state is a part of the `RemoteSequenceManager`'s state shared between the main manager and its slices. We fix some slicing bugs along the way.
# 2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**
# `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` is actually needed only under the hood of `RemoteSequenceManager`, so it can extract it by itself without exposing this low-level class to callers. If strictly necessary, a caller can provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided.
# 3. **Simplifies retry logic.**
# Previously, we could have "nested" retry loops: one in `._update()`, another in inference/forward/backward steps. The loop in `._update()` could introduce issues to concurrent inference/forward/backward calls, since it blocks the entire class if its delay period becomes too high. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, and any retries are triggered by the inference/forward/backward steps.
# 4. **Removes deprecated `RemoteTransformerBlock`.**
# `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long overdue.
# 5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**
# These functions duplicate the functionality of the `RemoteSequential` constructor.
# 6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**
# This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
# 1 year ago
|
|
|
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
    """
    A bloom config that contains information about DHT peers.

    To create a distributed model, one must provide dht_prefix. In addition, either initial_peers
    (which defaults to PUBLIC_INITIAL_PEERS) must be non-empty, or a hivemind.DHT instance must be
    passed to the model constructor via `dht=...`.
    """

    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers

    pre_seq_len: int = 0  # number of prefix tokens for prompt tuning (must be > 0 when a "ptune" mode is used)
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]

    # These settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    use_chunked_forward: Union[str, bool] = "auto"
    chunked_forward_step: int = 16384
|
|
|
|
|
|
|
|
|
|
|
|
# The nn.Module.register_parameter implementation in place when this module is imported.
# force_non_empty_weights() temporarily reinstalls it while another context manager
# (e.g. accelerate.init_empty_weights()) may have monkey-patched nn.Module.register_parameter.
original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    Inside the context, nn.Module.register_parameter is the implementation captured at import time
    (`original_register_parameter`). On exit, whether normal or via an exception, the implementation
    that was installed when entering the context is put back.

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    # Capture the current (possibly patched) implementation *before* the try block, so the finally
    # clause can never reference an unbound name and mask the real error.
    possibly_patched_register_parameter = nn.Module.register_parameter
    nn.Module.register_parameter = original_register_parameter
    try:
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
|
|
|
|
|
|
|
|
|
|
|
|
class _FromPretrainedDefaultsMixin:
    """
    Mixin that changes the defaults of `from_pretrained()` for Petals models:
    `low_cpu_mem_usage` defaults to True and `torch_dtype` defaults to "auto".
    Explicitly passed values (including False) are forwarded to the parent implementation unchanged.
    """

    @classmethod
    def from_pretrained(
        cls,
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"
        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)

    # Bound methods read __doc__ from the underlying function, so the patched text must be stored on
    # `from_pretrained.__func__` to be visible via help()/inspect.getdoc(); it is also kept on the
    # classmethod object itself, as before. The upstream docstring is None when Python runs with -OO,
    # so it is only patched when present (otherwise `.replace` would crash at import time).
    _upstream_doc = BloomPreTrainedModel.from_pretrained.__doc__
    if _upstream_doc is not None:
        from_pretrained.__func__.__doc__ = from_pretrained.__doc__ = _upstream_doc.replace(
            "low_cpu_mem_usage(`bool`, *optional*)",
            "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
        ).replace(
            "torch_dtype (`str` or `torch.dtype`, *optional*)",
            'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
        )
    del _upstream_doc
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
        r"^(intermediate_)?prompt_embeddings\.weight$",
    ]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
        """
        :param config: model config; config.dht_prefix must be set
        :param dht: optional hivemind.DHT forwarded to RemoteSequential; if omitted, config.initial_peers
            must be non-empty
        :raises NotImplementedError: if config.tuning_mode is set but is not a "ptune" mode
        """
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"

        # Temporarily set n_layer to 0 to prevent local layer initialization in BloomModel.__init__.
        # The value is restored in `finally` so that the caller's config is not left corrupted
        # if the parent constructor raises.
        n_layer, config.n_layer = config.n_layer, 0
        try:
            super().__init__(config)
        finally:
            config.n_layer = n_layer
        assert len(self.h) == 0

        # The transformer layers are replaced by a sequence of remote blocks
        self.h = RemoteSequential(config, dht=dht)

        # Forbid accumulate grads for embeddings and layernorm
        self.set_requires_grad(False)

        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
                    logger.info(
                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
                        "to increase ptune quality"
                    )
                # Created after set_requires_grad(False), so these embeddings stay trainable
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            # Use the config value: this module has no `tuning_mode` attribute of its own
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        """Set `requires_grad` to `value` for every parameter currently registered on this module."""
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        """
        Build the prompt-tuning prefixes for a batch.

        :param batch_size: number of sequences in the batch
        :returns: a tuple (prompts, intermediate_prompts), both cast to the dtype of word_embeddings.
            prompts has shape (batch_size, pre_seq_len, hidden_size). For "deep_ptune",
            intermediate_prompts has shape (len(self.h), batch_size, pre_seq_len, hidden_size);
            otherwise it is DUMMY.
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
            )
            # Move the per-layer axis first: (layers, batch, pre_seq_len, hidden_size)
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Embed the inputs locally, run them through the remote transformer layers, and apply ln_f.

        :param input_ids: token ids; mutually exclusive with inputs_embeds
        :param inputs_embeds: precomputed input embeddings; mutually exclusive with input_ids
        :param attention_mask: must be None, attention masks are not supported yet
        :param kwargs: other transformers arguments; values other than None/False are only logged
            at debug level and otherwise ignored
        :returns: BaseModelOutputWithPastAndCrossAttentions with only last_hidden_state set
        :raises ValueError: if both or neither of input_ids and inputs_embeds are given
        """
        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        use_ptune = bool(self.config.tuning_mode) and "ptune" in self.config.tuning_mode

        if use_ptune:
            # Prepend the trainable prefix to every sequence in the batch
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if use_ptune:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if use_ptune:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
    )

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # BloomPreTrainedModel.__init__ is called directly instead of super().__init__(), so the
        # parent constructor never runs (presumably to avoid building a local BloomModel — TODO confirm)
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        # The LM head is given the input word embeddings (see set_input_embeddings, which keeps them shared)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the local input word embeddings of the distributed transformer."""
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head, or None when config.tie_word_embeddings is set."""
        # NOTE(review): returning None for tied embeddings presumably keeps transformers from handling
        # lm_head as a separate output embedding — confirm against transformers' tie/resize logic.
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        """Replace the input embeddings and share them with the LM head."""
        assert isinstance(new_embeddings, nn.Embedding)
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        """
        Copy the weights (and bias) of `new_lm_head` into the existing LM head in place.

        A `new_lm_head` without bias zeroes the existing bias (if any), which matches a bias-free linear layer.

        :raises ValueError: if `new_lm_head` has a bias but the current LM head has none
        """
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            if new_lm_head.bias is not None:
                if self.lm_head.bias is None:
                    raise ValueError("Cannot copy the bias of new_lm_head: the current lm_head has no bias")
                self.lm_head.bias[...] = new_lm_head.bias
            elif self.lm_head.bias is not None:
                self.lm_head.bias.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
    """BloomForSequenceClassification whose transformer is a DistributedBloomModel (layers hosted by the swarm)."""

    _keys_to_ignore_on_load_missing = [
        *BloomForSequenceClassification._keys_to_ignore_on_load_missing,
        *DistributedBloomModel._keys_to_ignore_on_load_missing,
    ]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # Initialize through BloomPreTrainedModel directly rather than super().__init__(), so the
        # parent constructor never runs (presumably to avoid a local BloomModel — TODO confirm).
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        # Classification head: hidden_size -> num_labels without bias, cast to config.torch_dtype
        classifier = nn.Linear(config.hidden_size, config.num_labels, bias=False)
        self.score = classifier.to(config.torch_dtype)

        # Initialize weights and apply final processing
        self.post_init()
|