Quality of life changes: update readme, simplify run_server interface (#75)

- run_server now accepts the model name as either a positional or a keyword argument
- changed names in README to account for interface updates
- moved model conversion from README to a separate wiki page
- updated requirements.txt
Branch: justheuristic-patch-5
Author: justheuristic, committed via GitHub 2 years ago
Commit: 8caf1145a8 (parent 1046911dea)

@@ -35,7 +35,7 @@ This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a s
```python
# Initialize distributed BLOOM and connect to the swarm
model = DistributedBloomForCausalLM.from_pretrained(
"bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
"bigscience/bloom-petals", tuning_mode="ptune", initial_peers=SEE_BELOW
) # Embeddings & prompts are on your device, BLOOM blocks are distributed
print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
@@ -78,89 +78,121 @@ This is important because it's technically possible for peers serving model laye
## Installation
+🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two).
+Here's how to install the dependencies with conda:
+```
+conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+pip install bitsandbytes==0.33.2 # for 8-bit quantization
+pip install -r requirements.txt
+```
+This script uses Anaconda to install cuda-enabled PyTorch.
+If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
+If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
+If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+__OS support:__ currently, PETALS only supports Linux operating systems. On Windows 11, you can run PETALS with GPU enabled inside WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)).
+For macOS, you can *probably* run everything normally if you manage to install dependencies, but we do not guarantee this.
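
*Editor's note, not part of the diff:* as a quick sanity check after installing, you can confirm that PyTorch sees the GPU and was built against CUDA 11 or newer. A minimal sketch, assuming only a standard PyTorch install:

```python
import torch

# both should succeed if the conda/pip install above worked:
print("CUDA available:", torch.cuda.is_available())   # expect True on a GPU machine
print("Built against CUDA:", torch.version.cuda)      # expect "11.3" or newer for bitsandbytes
```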
+### Getting Started
+This is a toy example running on a local machine without a GPU and with a tiny model.
+For more detailed instructions with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
+First, run a couple of servers, each in a separate shell. The first server runs like this:
```bash
-conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-pip install -r requirements.txt
-pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+  --host_maddrs /ip4/127.0.0.1/tcp/31337   # use port 31337, local connections only
```
-### Basic functionality
+This server will host 8 (out of 24) layers for [this tiny bloom model](https://huggingface.co/bloom-testing/test-bloomd-560m-main) that was converted for PETALS.
+To run a different model, please see [this wiki page](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-PETALS).
-All tests is run on localhost
+Once the server has started, it will print out a ton of information, including an (important) line like this:
+```bash
+Mon Day 01:23:45.678 [INFO] Running DHT node on ['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters'], initial peers = []
+```
-First, run one or more servers like this:
+You can use this address (/ip4/whatever/else) to connect additional servers. Open another terminal and run:
```bash
-# minimalistic server with non-trained bloom blocks
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
-# when running multiple servers:
-# - give each server a unique --identity_path (or remove --identity_path arg when debugging)
-# - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
-# - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+  --host_maddrs /ip4/127.0.0.1/tcp/0 --initial_peers /ip4/127.0...<TODO! copy the address of another server>
+# e.g. --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS1GecIfYouAreReadingThisYouNeedToCopyYourServerAddressCBBq
```
-Then open a python notebook or console and run:
+You can assign `--initial_peers` to one or multiple addresses of other servers, not necessarily the first one.
+The only requirement is that at least one of them is alive, i.e. running at the time.
+Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
+make sure your servers have enough total `--num_blocks` to cover that model.
+Once you have enough servers, you can use them to run inference and/or train the model:
```python
import torch
-import hivemind
-from src import DistributedBloomConfig, get_remote_module
-dht = hivemind.DHT(
-    initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
-    client_mode=True, start=True,
-)
-config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
-assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
-# test forward/backward, two blocks
-outputs = layer4(layer3(torch.randn(1, 64, 4096)))
-loss = (outputs * torch.randn_like(outputs)).norm()
-loss.backward()
+import torch.nn.functional as F
+import transformers
+from src import DistributedBloomForCausalLM
-# test inference, one block
-with layer3.inference_session(max_length=10) as sess:
-    for i in range(10):
-        res = sess.step(torch.ones(1, 1, 4096))
```
+initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
+tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
+model = DistributedBloomForCausalLM.from_pretrained(
+  "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
+)  # this model has only embeddings / logits, all transformer blocks rely on remote servers
-### Convert regular BLOOM into distributed
-```bash
+inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
+remote_outputs = model.generate(inputs, max_length=10)
+print(tokenizer.decode(remote_outputs[0]))  # "a cat sat in the back of the car,"
-# convert model from HF hub to a distributed format (can take hours depending on your connection!)
-MY_WRITE_TOKEN=TODO_WRITE_TOKEN_FROM_https://huggingface.co/settings/token
-python -m cli.convert_model --model bigscience/bloom-6b3 \
-  --output_path ./converted_model --output_repo bigscience/test-bloomd-6b3 \
-  --use_auth_token $MY_WRITE_TOKEN  # ^-- todo replace output repo with something you have access to
# "train" input embeddings by backprop through distributed transformer blocks
model.transformer.word_embeddings.weight.requires_grad = True
outputs = model.forward(input_ids=inputs)
loss = F.cross_entropy(outputs.logits.flatten(0, 1), inputs.flatten())
loss.backward()
print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
```
Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
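
*Editor's note, not part of the diff:* a rough sketch of what that prompt-tuning notebook boils down to. `tuning_mode="ptune"` is confirmed by the snippet at the top of this README; `pre_seq_len` and the `labels`/`.loss` interface are assumptions, so treat the notebook as the source of truth:

```python
import torch
import transformers
from src import DistributedBloomForCausalLM

initial_peers = [TODO_put_one_or_more_server_addresses_here]  # same addresses as above
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
model = DistributedBloomForCausalLM.from_pretrained(
    "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers,
    tuning_mode="ptune", pre_seq_len=16,  # pre_seq_len is an assumption, check the notebook
)

# only the local prompt/embedding parameters require grad; remote blocks stay frozen
opt = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-2)

inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
loss = model(input_ids=inputs, labels=inputs).loss  # assumes the usual HF labels/.loss interface
opt.zero_grad()
loss.backward()
opt.step()
```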
-### Test local vs remote block (allclose)
+Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running PETALS.
-To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
-```bash
-# shell A: serve model
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
+### Development
-# shell B:
-export PYTHONPATH=.
-export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
-export MODEL_NAME="bigscience/test-bloomd-6b3"
+PETALS uses pytest with a few plugins. To install them, run `pip install -r requirements-dev.txt`.
-# test individual random blocks for exact match
-pytest tests/test_block_exact_match.py
+To run minimalistic tests, spin up some servers:
+```bash
+export MODEL_NAME=bloom-testing/test-bloomd-560m-main
+export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+python -m cli.run_server $MODEL_NAME --block_indices 0:12 --throughput 1 --torch_dtype float32 \
+  --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> server1.log &
+sleep 5  # wait for the first server to initialize DHT
+python -m cli.run_server $MODEL_NAME --block_indices 12:24 --throughput 1 --torch_dtype float32 \
+  --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g &> server2.log &
+tail -f server1.log server2.log  # view logs for both servers
+# after you're done, kill servers with 'pkill -f cli.run_server'
+```
-# test the full model
-pytest tests/test_full_model.py
+Then launch pytest:
+```
+export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m
+export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
+```
+The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml).
+We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests.
+Before committing your code, simply run `black . && isort .` and you will be fine.
--------------------------------------------------------------------------------

@@ -15,8 +15,11 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
-                        help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('--converted_model_name_or_path', type=str, default=None,
+                       help="path or name of a pretrained model, converted with cli/convert_model.py")
+    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
@@ -83,6 +86,8 @@ def main():
     args = vars(parser.parse_args())
     args.pop("config", None)
+    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
+
     if args.pop("increase_file_limit"):
         increase_file_limit()
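
*Editor's note, not part of the diff:* a self-contained sketch of the pattern this hunk introduces, using plain argparse instead of the configargparse used above. The argument names and the merge line are copied from the diff; everything else is illustrative. It shows why the positional and keyword forms resolve to the same value:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument('--converted_model_name_or_path', type=str, default=None)
group.add_argument('model', nargs='?', type=str)  # optional positional with the same meaning

for argv in (['bloom-testing/test-bloomd-560m-main'],                                     # positional form
             ['--converted_model_name_or_path', 'bloom-testing/test-bloomd-560m-main']):  # keyword form
    args = vars(parser.parse_args(argv))
    # same merge line as in run_server.py above: positional wins if given, else the keyword
    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
    assert args["converted_model_name_or_path"] == 'bloom-testing/test-bloomd-560m-main'
# passing both forms at once is rejected by the mutually exclusive group
```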

@@ -1,8 +1,8 @@
-torch==1.12.0
-bitsandbytes==0.33.0
+torch>=1.12
+bitsandbytes==0.33.0 #TODO update this to 0.33.2 asap
 accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.12.2,<4.0.0
-https://github.com/learning-at-home/hivemind/archive/131f82c97ea67510d552bb7a68138ad27cbfa5d4.zip
+hivemind==1.1.1
 humanfriendly
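
*Editor's note, not part of the diff:* a quick way to confirm the pins above resolved in your environment. A minimal sketch, assuming only that the packages from requirements.txt are installed:

```python
# print the versions that actually got installed, to compare against requirements.txt
import torch, transformers, hivemind

print("torch", torch.__version__)                # expect >= 1.12
print("transformers", transformers.__version__)  # expect 4.21.3
print("hivemind", hivemind.__version__)          # expect 1.1.1
```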
