Add automated tests (#23)

This PR makes basic checks run automatically on each subsequent PR (a rough local equivalent is sketched after the list below):

- convert a small model on every PR
- run existing tests on every PR
- enforce black / isort
- require checks on merge
- make sure tests are not flaky
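A rough local equivalent of these checks, as a hedged sketch: it assumes the pinned dev requirements (black 22.3.0, isort 5.10.1, pytest) are installed and, unlike CI, skips model conversion and server startup, which the Tests workflow below handles.

import subprocess

# Run the same style checks the CI enforces, then the test suite.
# Sketch only: the pytest step in CI also needs a converted model and running servers.
for cmd in (
    ["black", "--check", "--diff", "."],       # same options passed to psf/black in CI
    ["isort", "--check-only", "--diff", "."],
    ["python", "-m", "pytest", "tests/"],
):
    subprocess.run(cmd, check=True)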

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Dmitry Baranchuk <dmitrybaranchuk@gmail.com>
Committed by justheuristic, commit e2711a033b (parent f5463812ad)

@ -0,0 +1,26 @@
name: Check style
on:
push:
branches: [ master ]
pull_request:
jobs:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: psf/black@stable
with:
options: "--check --diff"
version: "22.3.0"
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: isort/isort-action@master
with:
isortVersion: "5.10.1"

@ -0,0 +1,89 @@
name: Tests
on:
push:
branches: [ master ]
pull_request:
jobs:
convert-model:
runs-on: ubuntu-latest
env:
BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
timeout-minutes: 15
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Cache dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: Key-v1-py3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Delete previous model, if exists
run: |
python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
name='test-bloomd-350m-$GITHUB_HEAD_REF', organization='bloom-testing')" || true
- name: Convert model and push to hub
run: |
python -m cli.convert_model --model bigscience/bloom-350m --output_path ./converted_model \
--output_repo bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
run-tests:
runs-on: ubuntu-latest
needs: convert-model
strategy:
matrix:
python-version: [ 3.7, 3.8, 3.9 ]
fail-fast: false
timeout-minutes: 15
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Cache dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Test
run: |
export MODEL_NAME=bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
SERVER1_PID=$!
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:24 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
SERVER2_PID=$!
sleep 30 # wait for server to download layers
# test individual blocks
export PYTHONPATH=.
BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
kill -s SIGINT $SERVER1_PID $SERVER2_PID
echo "Done!"

@ -10,8 +10,9 @@ from huggingface_hub import Repository
from tqdm.auto import tqdm
from src import BloomModel
from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from src.client import DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

(File diff suppressed because it is too large.)

@ -0,0 +1,10 @@
[tool.black]
line-length = 120
required-version = "22.3.0"
[tool.isort]
profile = "black"
line_length = 120
combine_as_imports = true
combine_star = true
known_local_folder = ["tests", "cli"]

@ -0,0 +1,6 @@
pytest==6.2.5 # see https://github.com/pytest-dev/pytest/issues/9621
pytest-forked
pytest-asyncio==0.16.0
black==22.3.0
isort==5.10.1
psutil

@ -0,0 +1,6 @@
torch==1.12.0
accelerate==0.10.0
huggingface-hub==0.7.0
bitsandbytes-cuda113==0.26.0
https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

@ -9,8 +9,15 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
pre_process_alibi_for_pad, split_tensor_along_last_dim)
from src.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,
build_alibi_tensor,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim,
)
class BloomAttention(nn.Module):

@ -10,14 +10,16 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward)
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig
@ -445,12 +447,27 @@ class LMHead(nn.Module):
self.word_embeddings = word_embeddings
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
@property
def in_features(self) -> int:
return self.word_embeddings.num_embeddings
@property
def out_features(self) -> int:
return self.word_embeddings.embedding_dim
@property
def weight(self):
return self.word_embeddings.weight
@property
def bias(self):
return None
def forward(self, hidden_states):
word_embeddings = self.word_embeddings.weight
# We use 'chunked_forward' only when embeddings are in half-precision on CPU.
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
word_embeddings.device.type == 'cpu':
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
lm_logits = self.chunked_forward(hidden_states)
else:
# Switch dtype in case word_embeddings are fp16/bf16
@ -459,20 +476,20 @@ class LMHead(nn.Module):
return lm_logits
def chunked_forward(self, hidden_states):
""" Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
chunk_size: provides trade-off between efficiency and extra memory consumption.
"""Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
chunk_size: provides trade-off between efficiency and extra memory consumption.
"""
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
word_embeddings = self.word_embeddings.weight
num_embeddings = self.word_embeddings.num_embeddings
hidden_states = hidden_states.float()
hidden_states = hidden_states.float()
output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
for i in range(0, num_embeddings, self.chunk_size):
chunk = word_embeddings[i: i + self.chunk_size].float()
output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
chunk = word_embeddings[i : i + self.chunk_size].float()
output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
return output
@ -565,7 +582,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None

@ -1,4 +1,4 @@
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.client.remote_sequential import RemoteSequential
from src.client.sequence_manager import RemoteSequenceManager

@ -2,15 +2,20 @@
import os
from typing import Optional, Tuple
import hivemind
import torch
import torch.nn as nn
import hivemind
from hivemind import get_logger, use_hivemind_log_handler
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
from src.bloom.model import (
BloomConfig,
BloomForCausalLM,
BloomForSequenceClassification,
BloomModel,
BloomPreTrainedModel,
LMHead,
)
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -25,12 +30,13 @@ class DistributedBloomConfig(BloomConfig):
initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
class DistributedBloomModel(BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
@ -49,7 +55,7 @@ class DistributedBloomModel(BloomModel):
)
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
self.h = RemoteSequential(config, dht, config.dht_prefix)
# Forbid accumulate grads for embeddings and layernorm
self.set_requires_grad(False)
@ -57,6 +63,14 @@ class DistributedBloomModel(BloomModel):
for p in self.parameters():
p.requires_grad = value
def forward(self, *args, use_cache=None, **kwargs):
if use_cache:
raise ValueError(
"Distributed forward does not support use_cache; for efficient cache-aware generation, "
"please use model.transformer.inference_session() or model.generate(...)"
)
return super().forward(*args, use_cache=False, **kwargs)
class DistributedBloomPrefix(DistributedBloomModel):
"""DistributedBloomModel with prefix tokens for prompt tuning"""
@ -76,7 +90,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
return prompts
def forward(
self,
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
@ -86,14 +100,16 @@ class DistributedBloomPrefix(DistributedBloomModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None
return_dict=None,
):
assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
assert (
input_ids is None or inputs_embeds is None
), "You cannot specify both input_ids and inputs_embeds at the same time"
assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
batch_size = inputs_embeds.shape[0]
if attention_mask is not None:
@ -104,25 +120,26 @@ class DistributedBloomPrefix(DistributedBloomModel):
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
transformer_outputs = super().forward(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
return_dict=return_dict,
)
# Remove prefix
last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
transformer_outputs['last_hidden_state'] = last_hidden_state
last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
transformer_outputs["last_hidden_state"] = last_hidden_state
return transformer_outputs
class DistributedBloomForCausalLM(BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
"""Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
@ -136,11 +153,23 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head.word_embeddings
def get_input_embeddings(self):
return self.transformer.word_embeddings
def set_output_embeddings(self, new_embeddings):
self.lm_head.word_embeddings.weight = new_embeddings.weight
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
return None
return self.lm_head
def set_input_embeddings(self, new_embeddings: nn.Embedding):
assert isinstance(new_embeddings, nn.Embedding)
self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
def set_output_embeddings(self, new_lm_head: nn.Linear):
with torch.no_grad():
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
self.lm_head.bias[...] = new_lm_head.bias
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):

@ -3,6 +3,7 @@ from __future__ import annotations
import contextlib
import logging
import random
from typing import Optional, Union
import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
@ -12,7 +13,7 @@ from torch import nn
import src
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos
@ -25,7 +26,15 @@ class RemoteSequential(nn.Module):
A sequence of transformer blocks hosted by the swarm.
"""
def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
def __init__(
self,
config: src.DistributedBloomConfig,
dht: DHT,
prefix: str,
max_retries: int = 3,
p2p: Optional[P2P] = None,
sequence_manager: Optional[RemoteSequenceManager] = None,
):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(
@ -39,12 +48,17 @@ class RemoteSequential(nn.Module):
self.dht = dht
self.prefix = prefix
self.max_retries = max_retries
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
logger.debug(f"Remote block uids: {block_uids}")
self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
if sequence_manager is None:
logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
self.sequence_manager = RemoteSequenceManager(dht, block_uids)
self.is_subsequence = False
else:
assert isinstance(sequence_manager.block_uids, list)
logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
self.is_subsequence = self.sequence_manager.block_uids == block_uids
def forward(self, inputs: torch.Tensor):
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@ -64,27 +78,38 @@ class RemoteSequential(nn.Module):
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
return inputs
def __getitem__(self, block_index: int):
assert 0 <= block_index < self.config.n_layer
(module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
return module
def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
assert isinstance(ix, (int, slice))
if isinstance(ix, int):
assert 0 <= ix < self.config.n_layer
(module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
return module
else:
return RemoteSequential(
self.config,
self.dht,
prefix=self.prefix,
max_retries=self.max_retries,
p2p=self.p2p,
sequence_manager=self.sequence_manager[ix],
)
def __iter__(self):
for block_index in range(self.config.n_layer):
yield self[block_index]
def __len__(self):
return len(self.remote_sequence_info)
return len(self.sequence_manager)
def inference_session(self) -> RemoteSequentialInferenceSession:
self.remote_sequence_info.update_()
return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
self.sequence_manager.update_()
return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
class RemoteSequentialInferenceSession:
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
self.remote_sequence_info = remote_sequence_info
self.p2p = p2p
self.closed = False
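A hedged usage sketch of the indexing added to RemoteSequential above. It assumes a running swarm that serves a converted model; the peer multiaddr and DHT prefix below are placeholders, not values from this PR.

import hivemind
from src.client import DistributedBloomConfig
from src.client.remote_sequential import RemoteSequential

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/<peer-id>"]                      # placeholder multiaddr
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")   # any converted model
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)

blocks = RemoteSequential(config, dht, prefix="test-bloomd-6b3")  # placeholder DHT prefix
first_block = blocks[0]                  # a single RemoteTransformerBlock
first_half = blocks[: len(blocks) // 2]  # a RemoteSequential reusing the parent's RemoteSequenceManager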

@ -1,29 +1,27 @@
from __future__ import annotations
import threading
from typing import List, NamedTuple, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union
from hivemind import DHT, PeerID
from hivemind import DHT, DHTExpiration
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
class RemoteSequenceInfo:
class RemoteSequenceManager:
"""Keeps and updates the meta-information about which peers host which blocks"""
dht: DHT
block_uids: List[ModuleUID]
block_infos: List[Optional[RemoteModuleInfo]]
spans_by_priority: List[Span] # sorted from best to worst
spans_containing_block: Tuple[List[Span]]
spans_by_priority: List[RemoteSpanInfo] # sorted from best to worst
spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
last_update_time: DHTExpiration
lock_changes: threading.Lock
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
@ -32,6 +30,7 @@ class RemoteSequenceInfo:
self.block_infos = [None] * len(self.block_uids)
self.spans_by_priority = []
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
self.last_update_time = -float("inf")
self.lock_changes = threading.Lock()
self.update_()
@ -39,6 +38,18 @@ class RemoteSequenceInfo:
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
"""Get a RemoteSequenceManager for a sub-sequence of blocks"""
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
with self.lock_changes:
subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
subseq.block_infos = self.block_infos[ix]
subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
subseq.last_update_time = self.last_update_time
return subseq
def update_(self):
with self.lock_changes:
self.update_block_infos_()
@ -67,15 +78,15 @@ class RemoteSequenceInfo:
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
peer_id not in info.servers or
info.servers[peer_id].state != ServerState.ONLINE or
block_index == len(block_infos) - 1
peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans

@ -23,5 +23,16 @@ class ServerInfo:
@dataclass
class RemoteModuleInfo:
"""A remote module that is served by one or more servers"""
uid: ModuleUID
servers: Dict[PeerID, ServerInfo]
@dataclass
class RemoteSpanInfo:
"""A chain of remote blocks served by one specific remote peer"""
start: int
end: int
peer_id: PeerID
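A brief illustration of why this is a mutable dataclass rather than the NamedTuple it replaces: the sequence manager extends a peer's span in place while scanning block infos. The peer_id below is a placeholder, not a real PeerID.

from src.data_structures import RemoteSpanInfo

span = RemoteSpanInfo(start=3, end=4, peer_id=None)  # peer_id=None is a stand-in for a real PeerID
span.end += 1        # extend the span to cover one more block; a NamedTuple would need _replace()
assert (span.start, span.end) == (3, 5)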

@ -136,8 +136,12 @@ async def _get_remote_module_infos(
try:
peer_id = PeerID.from_base58(peer_id)
state, throughput = server_info.value
if not (isinstance(state, int) and isinstance(throughput, float) and
math.isfinite(throughput) and throughput >= 0.0):
if not (
isinstance(state, int)
and isinstance(throughput, float)
and math.isfinite(throughput)
and throughput >= 0.0
):
raise ValueError(f"Invalid server info: {server_info}")
servers[peer_id] = ServerInfo(ServerState(state), throughput)
except (TypeError, ValueError) as e:

@ -9,10 +9,10 @@ def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[Remot
if module is None:
throughputs.append(0)
continue
throughputs.append(sum(server.throughput for server in module.servers.values()
if server.state != ServerState.OFFLINE))
throughputs.append(
sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
)
options = [(sorted(throughputs[i:i + num_blocks]), i)
for i in range(0, len(throughputs) - num_blocks + 1)]
options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
return list(range(best_start, best_start + num_blocks))
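To make the selection rule concrete, here is a worked example with made-up throughputs: the chosen window is the one whose sorted per-block throughputs compare smallest, i.e. the least-served stretch of blocks.

throughputs = [1.0, 0.0, 0.0, 2.0, 3.0]   # total throughput currently serving each block (made up)
num_blocks = 2
options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
assert best_start == 1                     # blocks 1 and 2 have zero throughput, so serve those
print(list(range(best_start, best_start + num_blocks)))  # -> [1, 2]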

@ -4,7 +4,7 @@ import multiprocessing as mp
import random
import threading
import time
from typing import Dict, Literal, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Union
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules, BloomConfig
from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
@ -98,7 +98,7 @@ class Server(threading.Thread):
cls,
prefix: Optional[str],
converted_model_name_or_path: str,
throughput: Union[float, Literal['auto', 'eval']],
throughput: Union[float, str],
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: Optional[int] = None,
@ -140,17 +140,15 @@ class Server(threading.Thread):
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
memory_cache = MemoryCache(device, cache_size_bytes)
assert isinstance(throughput, float) or throughput in ['auto', 'eval']
if throughput in ['auto', 'eval']:
throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token
)
block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
if block_indices is not None:
try:
@ -288,7 +286,7 @@ class ModuleAnnouncerThread(threading.Thread):
throughput: float,
update_period: float = 30,
expiration: float,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self.module_backends = module_backends

@ -20,10 +20,10 @@ use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
DEFAULT_CACHE_PATH = Path(Path.home(), '.cache', project_name, 'throughput.json')
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, 'throughput.lock')
DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], 'cli', 'speed_test.py')
SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
@dataclass
@ -43,7 +43,7 @@ def get_host_throughput(
# We use the system-wide lock since only one process at a time can measure the host throughput
os.makedirs(lock_path.parent, exist_ok=True)
with open(lock_path, 'wb') as lock_fd:
with open(lock_path, "wb") as lock_fd:
logger.info("Loading throughput info")
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
# The OS will release the lock when lock_fd is closed or the process is killed
@ -63,7 +63,7 @@ def get_host_throughput(
info = measure_throughput_info()
try:
os.makedirs(cache_path.parent, exist_ok=True)
with open(cache_path, 'w') as cache_fd:
with open(cache_path, "w") as cache_fd:
json.dump(asdict(info), cache_fd)
except Exception:
logger.exception(f"Failed to save throughput info in {cache_path}")
@ -73,29 +73,30 @@ def get_host_throughput(
def measure_throughput_info() -> ThroughputInfo:
logger.info("Measuring network, CPU, and GPU throughput. "
"This takes about a minute and will be cached for future runs")
logger.info(
"Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
)
# We measure throughput in "(inference) requests per second" (RPS) using a fixed model
config = BloomConfig.from_pretrained('bigscience/test-bloomd-6b3')
config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
network_rps = measure_network_rps(config)
device_rps = {'cpu': measure_device_rps('cpu', config)}
device_rps = {"cpu": measure_device_rps("cpu", config)}
if torch.cuda.is_available():
device_rps['cuda'] = measure_device_rps('cuda', config)
device_rps["cuda"] = measure_device_rps("cuda", config)
return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
def measure_network_rps(config: BloomConfig) -> float:
proc = subprocess.run([SPEED_TEST_PATH, '--json'], capture_output=True)
proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
if proc.returncode != 0:
raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
network_info = json.loads(proc.stdout)
bits_per_request = config.hidden_size * 32
network_rps = min(network_info['download'], network_info['upload']) / bits_per_request
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
logger.info(
f"Network throughput: "
@ -120,7 +121,7 @@ def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n
elapsed += time.perf_counter() - start_time
device_rps = n_steps / elapsed
device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == 'cuda' else 'CPU'
device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
return device_rps
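For intuition, the same RPS arithmetic with made-up speed-test numbers (in practice hidden_size comes from the model config):

hidden_size = 4096                      # hypothetical; taken from the model config in practice
bits_per_request = hidden_size * 32     # one fp32 hidden vector per inference request
download, upload = 100e6, 40e6          # bits per second reported by the speed test (made up)
network_rps = min(download, upload) / bits_per_request
print(f"Network throughput: {network_rps:.1f} RPS")  # -> Network throughput: 305.2 RPS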

Binary file not shown.

@ -3,6 +3,7 @@ import os
import hivemind
import torch
import transformers
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
@ -19,16 +20,18 @@ if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
remote_block = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
inputs = torch.randn(1, 8, 4096)
inputs = torch.randn(1, 8, ref_config.hidden_size)
(outputs_forward,) = remote_block(inputs)
outputs_inference = []

@ -0,0 +1,97 @@
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import torch
import transformers
from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
]
inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
outputs_rpc = remote_block.forward(inputs)[0]
outputs_rpc.sum().backward()
grads_rpc = inputs.grad
inputs.grad = None
hidden_states = inputs
for ref_block in ref_blocks:
hidden_states = ref_block.forward(hidden_states)[0]
outputs_ref = hidden_states
outputs_ref.sum().backward()
grads_ref = inputs.grad
assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
def test_chained_inference_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
inputs = torch.randn(1, 8, config.hidden_size)
outputs_inference = []
with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
]
outputs_ref = []
caches = [None, None]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]
for ref_block, cache in zip(ref_blocks, caches):
with torch.no_grad():
hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
new_caches.append(new_cache)
outputs_ref.append(hidden_states)
caches = new_caches
outputs_ref = torch.cat(outputs_ref, dim=1)
assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)

@ -1,59 +0,0 @@
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import torch
from hivemind.moe.expert_uid import ExpertInfo
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
BLOCK_UID = os.environ.get("BLOCK_UID")
if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
# seq_length > 128: rpc_forward_stream & rpc_backward_stream
# seq_length <= 128: rpc_forward & rpc_backward
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
(remote_block,) = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
]
inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
outputs_rpc = remote_block.forward(inputs)[0]
outputs_rpc.sum().backward()
grads_rpc = inputs.grad
inputs.grad = None
hidden_states = inputs
for ref_block in ref_blocks:
hidden_states = ref_block.forward(hidden_states)[0]
outputs_ref = hidden_states
outputs_ref.sum().backward()
grads_ref = inputs.grad
assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)

@ -1,64 +0,0 @@
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import torch
from hivemind.moe.expert_uid import ExpertInfo
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
BLOCK_UID = os.environ.get("BLOCK_UID")
if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
def test_remote_block_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
remote_block = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
inputs = torch.randn(1, 8, 4096)
outputs_inference = []
with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
]
outputs_ref = []
caches = [None, None]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]
for ref_block, cache in zip(ref_blocks, caches):
with torch.no_grad():
hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
new_caches.append(new_cache)
outputs_ref.append(hidden_states)
caches = new_caches
outputs_ref = torch.cat(outputs_ref, dim=1)
assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)

@ -24,9 +24,10 @@ if not MODEL_NAME:
REF_NAME = os.environ.get("REF_NAME")
def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.n_layer
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
@ -35,26 +36,29 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
logger.info("Forward outputs are finite")
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
with torch.no_grad():
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
del ref_model, ref_outputs
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
embs = model.transformer.word_embeddings(test_inputs)
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
with model.transformer.h.inference_session() as sess:
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
dictionary = model.transformer.word_embeddings.weight.t()
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
recurrent_outputs = (recurrent_outputs @ dictionary).float()
with torch.inference_mode():
embs = model.transformer.word_embeddings(test_inputs)
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
with model.transformer.h.inference_session() as sess:
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
dictionary = model.transformer.word_embeddings.weight.t()
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
recurrent_outputs = (recurrent_outputs @ dictionary).float()
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
