Fix floating point issues in block_selection.py (#89)

forward-backward-timeouts
Alexander Borzunov 1 year ago committed by GitHub
parent c07a7e0812
commit 898f614515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -60,7 +60,7 @@ A stable version of the code and a public swarm open to everyone will be release
### 📋 Terms of use ### 📋 Terms of use
Before using Petals to run a language model, please make sure that you are familiar with its terms of use, risks, and limitations. For BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license). Before using Petals to run a language model, please make sure that you are familiar with its terms of use, risks, and limitations. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
### 🔒 Privacy and security ### 🔒 Privacy and security
@ -101,7 +101,7 @@ For macOS, you can *probably* run everything normally if you manage to install d
## 🚀 Getting Started ## 🚀 Getting Started
This is a toy example running on a local machine without GPU and with a tiny model. This is a toy example running on a local machine without GPU and with a tiny model.
For a detailed instruction with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm). For a detailed instruction with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
First, run a couple of servers, each in a separate shell. To launch your first server, run: First, run a couple of servers, each in a separate shell. To launch your first server, run:
@ -133,7 +133,7 @@ You can assign `--initial_peers` to one or multiple addresses of other servers,
The only requirement is that at least one of them is running at the time. The only requirement is that at least one of them is running at the time.
Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model, Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
make sure your servers have enough total `--num_blocks` to cover that model. make sure your servers have enough total `--num_blocks` to cover that model.
Once your have enough servers, you can use them to train and/or inference the model: Once your have enough servers, you can use them to train and/or inference the model:
```python ```python
@ -162,8 +162,8 @@ print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
``` ```
Of course, this is a simplified code snippet. For actual training, see the example notebooks with "deep" prompt-tuning: Of course, this is a simplified code snippet. For actual training, see the example notebooks with "deep" prompt-tuning:
- Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb). - Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb)
- A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb). - A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb)
Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals. Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals.

@ -32,7 +32,10 @@ def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict
if module is None: if module is None:
continue continue
for peer_id, server in module.servers.items(): # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
for peer_id, server in sorted(module.servers.items()):
if server.state == ServerState.OFFLINE: if server.state == ServerState.OFFLINE:
continue continue
@ -47,17 +50,14 @@ def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict
return spans, throughputs return spans, throughputs
def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int: def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
options = ( options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
(sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
for i in range(0, len(throughputs) - num_blocks + 1)
)
return min(options)[-1] return min(options)[-1]
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = _compute_spans(module_infos) _, throughputs = _compute_spans(module_infos)
start = _choose_best_start(throughputs, num_blocks, None) start = _choose_best_start(throughputs, num_blocks)
return list(range(start, start + num_blocks)) return list(range(start, start + num_blocks))
@ -69,16 +69,22 @@ def should_choose_other_blocks(
spans, throughputs = _compute_spans(module_infos) spans, throughputs = _compute_spans(module_infos)
initial_throughput = throughputs.min() initial_throughput = throughputs.min()
eps = 1e-3
assert local_peer_id in spans, "Span served by this server is not present in the DHT" assert local_peer_id in spans, "Span served by this server is not present in the DHT"
local_span = spans[local_peer_id] local_span = spans[local_peer_id]
throughputs[local_span.start : local_span.end] -= local_span.throughput throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + eps)
# Without (1 + eps) here, we would sometimes subtract a value slightly less than local_span.throughput
# due to the floating point error, which would cause excess block replacements.
# Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer
# the previous server position in case of other things being almost equal.
new_start = _choose_best_start(throughputs, local_span.length, local_span.start) new_start = _choose_best_start(throughputs, local_span.length)
if local_span.start == new_start: if local_span.start == new_start:
return False # This server is on its best place already return False # This server is on its best place already
local_span.move_to(new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput * eps
local_span.move_to(new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput throughputs[local_span.start : local_span.end] += local_span.throughput
moved = True moved = True
@ -89,18 +95,18 @@ def should_choose_other_blocks(
moved = False moved = False
for peer_id in servers: for peer_id in servers:
span = spans[peer_id] span = spans[peer_id]
throughputs[span.start : span.end] -= span.throughput throughputs[span.start : span.end] -= span.throughput * (1 + eps)
new_start = _choose_best_start(throughputs, span.length, span.start) new_start = _choose_best_start(throughputs, span.length)
throughputs[span.start : span.end] += span.throughput * eps
if span.start != new_start: if span.start != new_start:
span.move_to(new_start) span.move_to(new_start)
moved = True moved = True
throughputs[span.start : span.end] += span.throughput throughputs[span.start : span.end] += span.throughput
new_throughput = throughputs.min() new_throughput = throughputs.min()
actual_quality = initial_throughput / new_throughput actual_quality = initial_throughput / new_throughput
logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%") logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
eps = 1e-6
return actual_quality < balance_quality - eps return actual_quality < balance_quality - eps

Loading…
Cancel
Save