@@ -34,7 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.


 class DistributedBloomModel(BloomModel):
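The hunk above only renames the config field; callers set it the same way as before. A minimal sketch of the new spelling, assuming the usual Hugging Face from_pretrained entry point and a placeholder checkpoint name (neither appears in this diff):

config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")  # checkpoint name is a placeholder
config.pre_seq_len = 16  # was `num_prefix_tokens` before this change; 0 disables prompt tuning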
@@ -112,11 +112,11 @@ class DistributedBloomPrefix(DistributedBloomModel):

     def __init__(self, config):
         super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        self.pre_seq_len = config.pre_seq_len

-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()

     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
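For reference, the prefix logic in this hunk boils down to a learned embedding table indexed by a fixed arange of ids. A self-contained sketch with toy dimensions (not the real BLOOM sizes):

import torch
import torch.nn as nn

batch_size, pre_seq_len, hidden_size = 4, 16, 64  # toy sizes for illustration
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
prefix_tokens = torch.arange(pre_seq_len).long()

# get_prompt(batch_size): one copy of the prefix ids per sample, then embed them
prefix = prefix_tokens.unsqueeze(0).expand(batch_size, -1)  # (batch_size, pre_seq_len)
prompts = prompt_embeddings(prefix)  # (batch_size, pre_seq_len, hidden_size)
assert prompts.shape == (batch_size, pre_seq_len, hidden_size)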
@@ -163,7 +163,7 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):

     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
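The constructor above dispatches on the renamed field. A hedged usage sketch (config construction elided; see the config example earlier):

config.pre_seq_len = 0
model = DistributedBloomForCausalLM(config)  # backbone: DistributedBloomModel
config.pre_seq_len = 16
model = DistributedBloomForCausalLM(config)  # backbone: DistributedBloomPrefix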
@@ -223,7 +223,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):

     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
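Since the old field name disappears entirely, callers that still set num_prefix_tokens must be updated. A hypothetical forward-compatibility shim (not part of this diff) could bridge old configs:

if getattr(config, "num_prefix_tokens", 0) and not getattr(config, "pre_seq_len", 0):
    config.pre_seq_len = config.num_prefix_tokens  # migrate the old name to the new one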