You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/src/petals/models/bloom/config.py

35 lines
1.3 KiB
Python

Add LLaMA support (#323) This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
import os
from typing import Optional, Union
from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedBloomBlock
attn_class = BloomAttention
block_prefix = "h"
num_key_value_groups = 1
Add LLaMA support (#323) This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
):
logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
Add LLaMA support (#323) This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None:
# We need "-petals" for backward compatibility with Petals < 1.2.0
dht_prefix = str(model_name_or_path) + "-petals"
logger.info(f"Using DHT prefix: {dht_prefix}")
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)