diff --git a/README.md b/README.md
index ea7919a..aa93a43 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
+python -m petals.cli.run_server stabilityai/StableBeluga2
 ```
 
 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16
+    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
 ```
 
 These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index eb5300e..effce82 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
-    if config.torch_dtype not in ("auto", None):
+    if config.torch_dtype not in ("auto", None, torch.float32):
+        # If config specifies float32, we override it to the default dtype below
        return config.torch_dtype
     return torch.bfloat16
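
For reference, a minimal sketch of how the patched `resolve_block_dtype` behaves after this change. The `DummyConfig` class below is a hypothetical stand-in for a `transformers.PretrainedConfig`, used here only to exercise the dtype branch; it is not part of the Petals codebase:

```python
import torch

# Hypothetical minimal stand-in for transformers.PretrainedConfig,
# carrying only the torch_dtype attribute that resolve_block_dtype inspects.
class DummyConfig:
    def __init__(self, torch_dtype):
        self.torch_dtype = torch_dtype

def resolve_block_dtype(config, dtype):
    """Mirrors the patched logic: an explicit dtype is returned intact,
    while a float32 config dtype is now treated like "auto" and falls
    through to the bfloat16 default."""
    if dtype not in ("auto", None):
        return dtype
    if config.torch_dtype not in ("auto", None, torch.float32):
        return config.torch_dtype
    return torch.bfloat16

# An explicitly requested dtype is always honored.
assert resolve_block_dtype(DummyConfig(torch.float32), torch.float16) == torch.float16
# A half-precision config dtype is still respected when dtype is "auto".
assert resolve_block_dtype(DummyConfig(torch.float16), "auto") == torch.float16
# New behavior: a float32 config dtype is overridden to the bfloat16 default.
assert resolve_block_dtype(DummyConfig(torch.float32), "auto") == torch.bfloat16
```

This is why the README commands no longer need `--torch_dtype float16`: a checkpoint whose config declares `torch_dtype: float32` now resolves to `bfloat16` by default, while any dtype the user passes explicitly is still returned unchanged.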