Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
from transformers.cache_utils import Cache
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
|
|
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
|
|
|
|
|
|
|
|
from petals.client.from_pretrained import FromPretrainedMixin
|
|
|
|
from petals.client.lm_head import LMHead
|
|
|
|
from petals.client.ptune import PTuneMixin
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
from petals.client.remote_sequential import RemoteSequential
|
|
|
|
from petals.models.bloom.config import DistributedBloomConfig
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
|
|
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
|
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
_keys_to_ignore_on_load_unexpected = [r"^h\."]
|
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
|
|
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
|
|
|
super().__init__(config)
|
|
|
|
assert len(self.h) == 0
|
|
|
|
config.num_hidden_layers = n_layer
|
|
|
|
|
|
|
|
self.h = RemoteSequential(config, dht=dht)
|
|
|
|
|
|
|
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
self.init_prompts(config)
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
input_ids: Optional[torch.LongTensor] = None,
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
past_key_values: Optional[RemotePastKeyValues] = None,
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
head_mask: Optional[torch.LongTensor] = None,
|
|
|
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
|
|
|
use_cache: Optional[bool] = None,
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
return_dict: Optional[bool] = None,
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
):
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
elif input_ids is not None:
|
|
|
|
input_shape = input_ids.size()
|
|
|
|
input_ids = input_ids.view(-1, input_shape[-1])
|
|
|
|
elif inputs_embeds is not None:
|
|
|
|
input_shape = inputs_embeds.size()[:-1]
|
|
|
|
else:
|
|
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
|
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
# The causal mask will be added on the server-side
|
|
|
|
assert (
|
|
|
|
attention_mask is None or (attention_mask == 1).all()
|
|
|
|
), f"Custom attention masks are not supported, {attention_mask=}"
|
|
|
|
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
|
|
|
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
assert not output_attentions, f"{output_attentions=} is not supported"
|
|
|
|
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
|
|
|
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
|
|
|
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
if inputs_embeds is None:
|
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
|
|
|
|
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
|
|
|
if use_prompts:
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
batch_size = inputs_embeds.shape[0]
|
|
|
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
|
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
else:
|
|
|
|
prompts = intermediate_prompts = None
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
|
|
|
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.
Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.
### Breaking changes
If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:
```python
# Using default session
with model.inference_session(max_length=100):
output_ids = model.generate(input_ids, max_new_tokens=3)
# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```
Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
9 months ago
|
|
|
hidden_states = self.h(
|
|
|
|
hidden_states,
|
|
|
|
prompts=intermediate_prompts,
|
|
|
|
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
|
|
|
)
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
|
|
|
|
# Remove prefix
|
|
|
|
if use_prompts:
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
|
|
|
|
|
|
|
if past_key_values is None:
|
|
|
|
past_key_values = RemotePastKeyValues()
|
|
|
|
past_key_values.update_seen(hidden_states.size(1))
|
|
|
|
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
# Add last hidden state
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
hidden_states = hidden_states.view(output_shape)
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
|
last_hidden_state=hidden_states,
|
|
|
|
past_key_values=past_key_values,
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
hidden_states=None,
|
|
|
|
attentions=None,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
|
|
|
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
|
|
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
|
|
|
_supports_cache_class = True
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
|
BloomPreTrainedModel.__init__(self, config)
|
|
|
|
self.transformer = DistributedBloomModel(config)
|
|
|
|
self.lm_head = LMHead(config)
|
|
|
|
|
|
|
|
# Initialize weights and apply final processing
|
|
|
|
self.post_init()
|
|
|
|
|
|
|
|
def prepare_inputs_for_generation(
|
|
|
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
|
|
|
) -> dict:
|
|
|
|
# Omit tokens covered by past_key_values
|
|
|
|
if past_key_values is not None:
|
|
|
|
if isinstance(past_key_values, Cache):
|
|
|
|
cache_length = past_key_values.get_seq_length()
|
|
|
|
past_length = past_key_values.seen_tokens
|
|
|
|
max_cache_length = past_key_values.get_max_length()
|
|
|
|
else:
|
|
|
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
|
|
|
max_cache_length = None
|
|
|
|
|
|
|
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
|
|
|
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
|
|
|
elif past_length < input_ids.shape[1]:
|
|
|
|
input_ids = input_ids[:, past_length:]
|
|
|
|
|
|
|
|
if (
|
|
|
|
max_cache_length is not None
|
|
|
|
and attention_mask is not None
|
|
|
|
and cache_length + input_ids.shape[1] > max_cache_length
|
|
|
|
):
|
|
|
|
attention_mask = attention_mask[:, -max_cache_length:]
|
|
|
|
|
|
|
|
position_ids = kwargs.get("position_ids", None)
|
|
|
|
if attention_mask is not None and position_ids is None:
|
|
|
|
# create position_ids on the fly for batch generation
|
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
|
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
|
|
if past_key_values:
|
|
|
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
|
|
|
|
|
|
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
|
|
if inputs_embeds is not None and past_key_values is None:
|
|
|
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
|
|
else:
|
|
|
|
model_inputs = {"input_ids": input_ids}
|
|
|
|
|
|
|
|
model_inputs.update(
|
|
|
|
{
|
|
|
|
"position_ids": position_ids,
|
|
|
|
"past_key_values": past_key_values,
|
|
|
|
"use_cache": kwargs.get("use_cache"),
|
|
|
|
"attention_mask": attention_mask,
|
|
|
|
}
|
|
|
|
)
|
|
|
|
return model_inputs
|
|
|
|
|
|
|
|
def _temporary_reorder_cache(self, past_key_values, beam_idx):
|
|
|
|
return past_key_values
|
|
|
|
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
def get_output_embeddings(self):
|
|
|
|
return self.lm_head
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
|
|
|
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
|
BloomPreTrainedModel.__init__(self, config)
|
|
|
|
self.num_labels = config.num_labels
|
|
|
|
|
|
|
|
self.transformer = DistributedBloomModel(config)
|
|
|
|
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
Add LLaMA support (#323)
This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
- LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.
Upgrade instructions:
- Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
|
|
|
|
|
|
|
# Initialize weights and apply final processing
|
|
|
|
self.post_init()
|