We need to sample the next server using its throughput as the weight to actually achieve max throughput for fine-tuning.
As an example, imagine a situation where we have 3 servers with throughputs [1000, 500, 1] hosting the same blocks, then compare the uniform and weighted sampling strategies.
This PR:
1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**
The config is provided by caller and never changed from inside `RemoteSequenceManager`. The state is a part of the `RemoteSequenceManager`'s state shared between the main manager and its slices. We fix some slicing bugs along the way.
2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**
`dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` actually needed only under the hood of `RemoteSequenceManager`, so it can extract it by itself without exposing this low-level class to callers. If strictly necessary, a caller can provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided.
3. **Simplifies retry logic.**
Previously, we could have "nested" retry loops: one in `._update()`, another in inference/forward/backward steps. The loop in `._update()` could introduce issues to concurrent inference/forward/backward calls, since it blocks the entire class if its delay period becomes too high. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, any retries are triggered by the inference/forward/backward steps.
4. **Removes deprecated `RemoteTransformerBlock`.**
`RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long due.
5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**
This functions duplicate the functionality of the `RemoteSequential` constructor.
6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**
This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of dtype of the weights (https://github.com/huggingface/accelerate/pull/920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
**Why?**
- We'd like to avoid excess threads for the original sequence manager in case if we only use its slices (e.g. when we add adapters or need only a subset of model blocks):
- If we create a sequence manager just before a fork (e.g. in a web app backend or a multi-thread benchmark), we'd like to avoid excess threads in the original process and only use this thread in child processes where we actually call `.make_sequence()`.
`use_auto_relay=True` makes the libp2p daemon look for relays to become reachable if we are behind NAT/firewall. However, being reachable is not necessary for the Petals client, and we should not spend the relays' capacity on this.
This PR fixes issues of #290:
- hivemind bfloat16 codec crashed on dummy tensors (with 0 elements), see https://github.com/learning-at-home/hivemind/pull/560 (this PR makes Petals depend on the latest hivemind version from the repo, it's temporary)
- transformers version check mismatched with the version allowed in `setup.cfg`
Also:
- This PR enables 8-bit by default for TP. Even though TP in 8-bit may be slower, we currently prefer to host more blocks to increase the network's stability.
- new bitsandbytes supports newer *and* older GPUs
- new hivemind supports a better bfloat16 codec
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
For some reasons, right now 15 sec is not enough to connect to the bootstrap peers in the public swarm, as reported by multiple users and observed by me. Increasing it to 120 sec until we find the root cause of the issue.
This PR increases `request_timeout`, since the previous default of 30 sec is not enough for many use cases.
Previously, we kept the request timeout low since we assumed that the server could freeze on dial if the target peer is behind a firewall. However, apparently, it won't freeze because libp2p has its own [dial timeout](https://github.com/libp2p/go-libp2p/blob/v0.26.0/core/network/context.go#L11).
Before this PR, `model.generate()` returned one excess token when resuming generation with an existing (the last token of the previous session, `session.last_token_id`). This is an unexpected behavior not convenient for the downstream apps, so this PR changes it until it's too late.
Even if the swarm seems to have at least 2 servers for each block, turning off on one of the servers could break it. That's because once a server is turned off, others may move to a better position, creating a significant downtime on their way. This PR prohibits switching blocks if it would make the swarm disjoint along the way.