Fix server warnings, update license links and readme (#602)

pull/598/head
Alexander Borzunov 3 months ago committed by GitHub
parent 67ca11a282
commit 8ad5513bea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,16 +8,14 @@
<br> <br>
</p> </p>
**Warning: Llama 3.1 support is still under construction!** the latest models require custom RoPE configuration that we do not have in Petals yet; we will update the code to fix that within a day. Generate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x7B), **Falcon** (40B+), or **BLOOM** (176B) and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
Generate text with distributed **Llama (1-3)** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
```python ```python
from transformers import AutoTokenizer from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM from petals import AutoDistributedModelForCausalLM
# Choose any model available at https://health.petals.dev # Choose any model available at https://health.petals.dev
model_name = "petals-team/StableBeluga2" # This one is fine-tuned Llama 2 (70B) model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"
# Connect to a distributed network hosting model layers # Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -33,22 +31,26 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b> 🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
</p> </p>
🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. 🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). 🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)! 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
## Connect your GPU and increase Petals capacity ## Connect your GPU and increase Petals capacity
Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU: Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can help serving one of the [available models](https://health.petals.dev) or host a new model from 🤗 [Model Hub](https://huggingface.co/models)!
As an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU:
🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model.
🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD): 🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
```bash ```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server petals-team/StableBeluga2 python -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
``` ```
🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki. 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
@ -58,7 +60,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2
```bash ```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
learningathome/petals:main \ learningathome/petals:main \
python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2 python -m petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct
``` ```
🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands: 🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
@ -66,19 +68,17 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach
```bash ```bash
brew install python brew install python
python3 -m pip install git+https://github.com/bigscience-workshop/petals python3 -m pip install git+https://github.com/bigscience-workshop/petals
python3 -m petals.cli.run_server petals-team/StableBeluga2 python3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
``` ```
<p align="center"> <p align="center">
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.) 📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
</p> </p>
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`. 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
## How does it work? ## How does it work?
@ -122,22 +122,39 @@ Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing. Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
### 📜 Citation ### 📜 Citations
Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel. Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188) [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
_arXiv preprint arXiv:2209.01188,_ 2022. _Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023.
```bibtex ```bibtex
@article{borzunov2022petals, @inproceedings{borzunov2023petals,
title = {Petals: Collaborative Inference and Fine-tuning of Large Models}, title = {Petals: Collaborative Inference and Fine-tuning of Large Models},
author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Ryabinin, Max and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin}, author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
journal = {arXiv preprint arXiv:2209.01188}, booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
year = {2022}, pages = {558--568},
year = {2023},
url = {https://arxiv.org/abs/2209.01188} url = {https://arxiv.org/abs/2209.01188}
} }
``` ```
Alexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel.
[Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361)
_Advances in Neural Information Processing Systems_ 36 (2024).
```bibtex
@inproceedings{borzunov2023distributed,
title = {Distributed inference and fine-tuning of large language models over the {I}nternet},
author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin},
booktitle = {Advances in Neural Information Processing Systems},
volume = {36},
pages = {12312--12331},
year = {2023},
url = {https://arxiv.org/abs/2312.08361}
}
```
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
<p align="center"> <p align="center">

@ -7,7 +7,7 @@ from typing import Optional, Tuple
import torch import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor
from petals.utils.misc import is_dummy from petals.utils.misc import is_dummy

@ -24,7 +24,7 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
def from_pretrained( def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
): ):
logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license") logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license")
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None: if loading_from_repo and dht_prefix is None:

@ -15,7 +15,6 @@ from transformers.models.llama.modeling_llama import (
LlamaConfig, LlamaConfig,
LlamaDecoderLayer, LlamaDecoderLayer,
LlamaMLP, LlamaMLP,
LlamaModel,
LlamaRMSNorm, LlamaRMSNorm,
repeat_kv, repeat_kv,
rotate_half, rotate_half,
@ -132,7 +131,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig): def __init__(self, config: LlamaConfig):
nn.Module.__init__(self) nn.Module.__init__(self)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config) self.self_attn = OptimizedLlamaAttention(config=config, layer_idx=0)
# layer_idx only matters for KV caching, and we re-implement it in Petals
self.mlp = LlamaMLP(config) self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@ -27,8 +27,8 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
): ):
logger.info( logger.info(
"Make sure you follow the LLaMA's terms of use: " "Make sure you follow the Llama terms of use: "
"https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1" "https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license"
) )
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)

@ -1,4 +1,3 @@
import json
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
@ -8,7 +7,7 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa,
) )
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
class WrappedMixtralBlock(MixtralDecoderLayer): class WrappedMixtralBlock(MixtralDecoderLayer):

@ -64,10 +64,6 @@ def load_pretrained_block(
max_disk_space=max_disk_space, max_disk_space=max_disk_space,
) )
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=False)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters(): for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict" assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name] param = state_dict[param_name]
@ -76,7 +72,6 @@ def load_pretrained_block(
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
logger.info(f"Loaded {model_name} block {block_index}") logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}")
return block return block

@ -267,7 +267,7 @@ def estimate_adapter_memory_per_block(
**load_peft_kwargs, **load_peft_kwargs,
) -> int: ) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block""" """Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True): with init_empty_weights(include_buffers=False):
block = get_model_block(block_config) block = get_model_block(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters()) base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block) create_lora_adapter(block)

Loading…
Cancel
Save