Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
import dataclasses
|
|
|
|
from contextlib import contextmanager
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from hivemind import get_logger
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
|
|
|
from petals.utils.misc import DUMMY
|
|
|
|
|
|
|
|
# Module-level logger obtained via hivemind's get_logger, named after this module.
# NOTE(review): not referenced by any code visible in this file section.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class PTuneConfig:
    """Prompt-tuning options.

    NOTE(review): presumably combined with a model config so that
    ``PTuneMixin.init_prompts`` can read ``pre_seq_len`` and ``tuning_mode``; confirm.
    """

    # How many trainable prefix tokens are used for prompt tuning.
    pre_seq_len: int = dataclasses.field(default=0)
    # Fine-tuning regime: None (disabled), "ptune", or "deep_ptune".
    tuning_mode: Optional[str] = dataclasses.field(default=None)
|
|
|
|
|
|
|
|
|
|
|
|
class PTuneMixin:
    """Mixin that adds prompt tuning ("ptune") and deep prompt tuning ("deep_ptune").

    The host class must provide ``self.config`` and ``self.word_embeddings``;
    both are read in :meth:`get_prompt`.
    NOTE(review): confirm every model using this mixin defines them.
    """

    # Regex for checkpoint keys that may be absent. The prompt embeddings are created by
    # init_prompts(), not loaded from a checkpoint.
    # Presumably consumed by transformers' from_pretrained to silence missing-key warnings; confirm.
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        """Create trainable prompt embeddings according to ``config.tuning_mode``.

        Does nothing when ``config.tuning_mode`` is falsy (e.g. None).

        Args:
            config: model config. Reads ``tuning_mode``, ``pre_seq_len`` and ``hidden_size``,
                plus ``num_hidden_layers`` for "deep_ptune".

        Raises:
            AssertionError: if a ptune mode is requested with ``pre_seq_len <= 0``.
            NotImplementedError: if ``tuning_mode`` is set but does not contain "ptune".
        """
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            # Token ids 0..pre_seq_len-1 used to index the prompt embedding tables
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            # Bypass the meta-tensor patch of register_parameter so these weights get real storage
            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    # One hidden_size-wide prompt per layer, packed into a single embedding row
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            # Read the mode from the config: this mixin never sets self.tuning_mode, so formatting
            # the message with it raised AttributeError instead of NotImplementedError.
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size: int):
        """Build the prompt tensors for a batch.

        Requires a prior ``init_prompts`` call with a ptune mode (uses ``self.prefix_tokens``).

        Args:
            batch_size: number of sequences to build prompts for.

        Returns:
            A ``(prompts, intermediate_prompts)`` pair, both cast to the dtype of
            ``self.word_embeddings.weight``:

            - ``prompts``: shape ``[batch_size, pre_seq_len, hidden_size]``.
            - ``intermediate_prompts``: in "deep_ptune" mode, shape
              ``[num_hidden_layers, batch_size, pre_seq_len, hidden_size]``;
              otherwise the ``DUMMY`` placeholder from ``petals.utils.misc``.
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        # Embedding lookups must happen on the same device as the model's embeddings
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            # Move the layer axis first: [layers, batch, pre_seq_len, hidden]
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        # Prompts are stored in float32 (see init_prompts) but returned in the model's working dtype
        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# Snapshot of nn.Module.register_parameter, taken once when this module is imported.
# force_non_empty_weights() temporarily reinstates it to bypass patches such as the one
# installed by accelerate.init_empty_weights().
# NOTE(review): assumes no such patch is active at import time; confirm.
_original_register_parameter = nn.Module.register_parameter
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def force_non_empty_weights():
    """Temporarily reinstate the unpatched ``nn.Module.register_parameter``.

    Inside this context, parameters are registered with the implementation captured at
    import time. This bypasses ``accelerate.init_empty_weights()``, which is used when
    ``low_cpu_mem_usage=True`` and turns every nn.Parameter into a PyTorch meta tensor.

    On exit, whatever ``register_parameter`` was installed on entry (patched or not) is
    restored, even if the body raises.

    The transformers library should replace all meta tensors by empty tensors by itself,
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    module_cls = nn.Module
    # Remember the current (possibly patched) hook so it can be put back afterwards
    active_hook = module_cls.register_parameter
    module_cls.register_parameter = _original_register_parameter
    try:
        yield
    finally:
        module_cls.register_parameter = active_hook
|