deep-prompt-tuning
dbaranchuk 2 years ago
parent 1477bdc471
commit 5b06dc2255

@ -34,7 +34,7 @@ class DistributedBloomConfig(BloomConfig):
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
class DistributedBloomModel(BloomModel):
@ -112,11 +112,11 @@ class DistributedBloomPrefix(DistributedBloomModel):
def __init__(self, config):
super().__init__(config)
assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
self.prefix_length = config.num_prefix_tokens
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
self.pre_seq_len = config.pre_seq_len
self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
self.prefix_tokens = torch.arange(self.prefix_length).long()
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
@ -163,7 +163,7 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
def __init__(self, config: DistributedBloomConfig):
BloomPreTrainedModel.__init__(self, config)
if config.num_prefix_tokens > 0:
if config.pre_seq_len > 0:
self.transformer = DistributedBloomPrefix(config)
else:
self.transformer = DistributedBloomModel(config)
@ -223,7 +223,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
def __init__(self, config: DistributedBloomConfig):
super().__init__(config)
if config.num_prefix_tokens > 0:
if config.pre_seq_len > 0:
self.transformer = DistributedBloomPrefix(config)
else:
self.transformer = DistributedBloomModel(config)

Loading…
Cancel
Save