You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/src/petals/models/llama/model.py

152 lines
5.9 KiB
Python

from typing import Optional
import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig
# Module-level logger (hivemind's get_logger); DistributedLlamaModel.forward uses it for debug messages
logger = get_logger(__name__)
class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    """LlamaModel, but all transformer layers are hosted by the swarm.

    Token embeddings (``embed_tokens``) and the final norm (``norm``) run locally,
    while ``layers`` is replaced by a ``RemoteSequential`` that forwards hidden
    states to remote servers.
    """

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    # Decoder-layer weights are not held by this client, so ignore them in checkpoints
    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
        """Build the client-side model.

        :param config: model config. ``num_hidden_layers`` is temporarily set to 0 while the
            parent constructor runs and is restored afterwards, even if that constructor raises.
        :param dht: optional DHT instance, forwarded to ``RemoteSequential``.
        """
        # Pretend there are no layers so LlamaModel.__init__ does not create local decoder layers.
        # ``config`` may be shared with other objects, so always put the real value back.
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        try:
            super().__init__(config)
        finally:
            config.num_hidden_layers = n_layer
        assert len(self.layers) == 0

        self.layers = RemoteSequential(config, dht=dht)

        # Forbid accumulate grads for embeddings and layernorm. This runs before init_prompts(),
        # so parameters created there are not affected by this call.
        self.requires_grad_(False)
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Embed inputs locally, run them through the remote layers and apply the final norm.

        :param input_ids: token ids; leading dims are flattened into one batch dim before
            embedding and restored on the output.
        :param inputs_embeds: precomputed embeddings, used instead of ``input_ids``.
        :param attention_mask: must be ``None`` or all ones; custom masks are not supported.
        :param kwargs: other HF keyword arguments; non-trivial values are logged and ignored.
        :returns: ``BaseModelOutputWithPast`` with only ``last_hidden_state`` populated.
        :raises ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are given,
            or if ``attention_mask`` masks out any position.
        """
        # An all-ones mask (1 = attend, HF convention) masks nothing out, e.g. tokenizer output
        # without padding, so it is accepted. Any other mask would be silently ignored downstream.
        if attention_mask is not None and not bool((attention_mask == 1).all()):
            raise ValueError(f"{self.__class__.__name__} does not support custom attention masks right now")
        for k, v in kwargs.items():
            if not (v is None or v is False):
                # Lazy %-args: avoid stringifying (possibly large) tensors when debug logging is off
                logger.debug("Extra keyword arguments are not yet supported (got %s = %s)", k, v)

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        use_ptune = bool(self.config.tuning_mode) and "ptune" in self.config.tuning_mode
        if use_ptune:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            # Prepend the prompt embeddings along the sequence dimension
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = inputs_embeds
        # Shape of the result WITHOUT the prompt prefix, restoring any flattened leading dims
        output_shape = input_shape + (hidden_states.size(-1),)

        if use_ptune:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
            # Remove prefix
            hidden_states = hidden_states[:, self.pre_seq_len :]
        else:
            hidden_states = self.layers(hidden_states)

        # Apply the final norm, then restore the caller's leading dims
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        # LlamaModel has no embedding layernorm here, so expose a no-op module
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm
class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    """Causal-LM LLaMA whose backbone is a ``DistributedLlamaModel`` and whose head is an ``LMHead``."""

    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        # Call LlamaPreTrainedModel.__init__ directly, skipping LlamaForCausalLM.__init__
        # (presumably so it does not build a local backbone/head); the distributed parts follow.
        LlamaPreTrainedModel.__init__(self, config)

        self.model = DistributedLlamaModel(config)
        self.pretraining_tp, self.vocab_size = config.pretraining_tp, config.vocab_size
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> LMHead:
        """Return the output projection (the ``LMHead`` built in ``__init__``)."""
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:
        """Alias for ``self.model``, the name RemoteGenerationMixin looks up."""
        return self.model
class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    """Sequence-classification LLaMA whose backbone is a ``DistributedLlamaModel``.

    The classification head (``score``) is a local bias-free linear layer mapping
    ``hidden_size`` features to ``num_labels`` logits.
    """

    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        # Call LlamaPreTrainedModel.__init__ directly, skipping LlamaForSequenceClassification.__init__;
        # the distributed backbone and the local head are attached below.
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:
        # Alias for ``self.model``, mirroring DistributedLlamaForCausalLM.transformer.
        # NOTE(review): this class does not mix in RemoteGenerationMixin; the alias is kept for
        # interface parity only. Confirm whether other code relies on it.
        return self.model