Merge branch 'client' into main
commit d42e8abd38
@@ -0,0 +1,49 @@
# this code is in active development, interfaces may change
import os
from typing import Optional, Union

import hivemind
from hivemind import DHT, get_logger, use_hivemind_log_handler

from src.bloom import BloomForCausalLM, DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class DistributedBloomForCausalLM(BloomForCausalLM):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        super().__init__(config)
        assert len(self.transformer.h) == 0
        config.n_layer = n_layer
        self.transformer.h = RemoteSequential(config, dht, prefix)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        if 'initial_peers' not in kwargs:
            raise ValueError("Please specify initial_peers=...")
        dht = hivemind.DHT(
            initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
            start=True)

        if 'prefix' not in kwargs:
            logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
            assert UID_DELIMITER not in pretrained_model_name_or_path, \
                f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
        prefix = kwargs.pop("prefix", pretrained_model_name_or_path)

        config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
        model = cls(config, dht, prefix)
        model.load_state_dict(_load_state_dict(
            pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
        ), strict=True)
        return model
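# Example usage (sketch; the checkpoint name and peer multiaddress below are hypothetical
# placeholders, any model whose blocks are announced under its name in the DHT works the same):
#
#   model = DistributedBloomForCausalLM.from_pretrained(
#       "bigscience/test-bloomd-6b3",  # assumed checkpoint name
#       initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/Qm..."],  # assumed peer multiaddr
#   )
#   logits = model.forward(input_ids).logits  # input_ids: LongTensor [batch, seq_len]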
@@ -0,0 +1,93 @@
from __future__ import annotations

import dataclasses
import threading
from functools import partial
from typing import Tuple, List, Optional, Sequence, NamedTuple

from hivemind import DHT, PeerID
from hivemind.utils.logging import use_hivemind_log_handler, get_logger

from src.data_structures import ModuleUID, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
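# A Span is a half-open interval [start, end) of consecutive blocks served by one peer:
# e.g. Span(start=2, end=5, peer_id=p) means peer p hosts blocks 2, 3 and 4.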

@dataclasses.dataclass(frozen=False, init=False)
class RemoteSequenceInfo:
    """Keeps and updates the meta-information about which peers host which blocks"""
    dht: DHT
    block_uids: List[ModuleUID]
    block_infos: List[Optional[RemoteModuleInfo]]
    spans_by_priority: List[Span]  # sorted from best to worst
    spans_containing_block: Tuple[List[Span], ...]
    lock_changes: threading.Lock

    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
        self.dht = dht
        self.block_uids = list(block_uids)
        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
        self.spans_by_priority = []
        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
        self.lock_changes = threading.Lock()
        self.update_()

        for uid, info in zip(self.block_uids, self.block_infos):
            assert info is not None, f"Found no remote peers for block {uid}"
        assert self.spans_by_priority and self.spans_containing_block

    def update_(self):
        with self.lock_changes:
            self.update_block_infos_()
            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

    def update_block_infos_(self):
        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
            return_future=False)
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.warning(f"Found no block info for block {uid}")
                continue  # skip: the checks below would raise AttributeError on None
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
                continue  # skip entries of unknown type
            if not info.peer_ids:
                logger.warning(f"Found no active peers for block {uid}")
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
            if not isinstance(info.peer_ids, set):
                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
            self.block_infos[block_index] = info

    @staticmethod
    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
        # Single left-to-right sweep: grow a span for each peer while it keeps hosting
        # consecutive blocks; close the span once the peer disappears (or at the last block).
        closed_spans = []
        active_spans = {}
        for block_index, info in enumerate(block_infos):
            if info is None:  # no entry for this block: close all open spans
                closed_spans.extend(active_spans.values())
                active_spans.clear()
                continue
            for peer_id in info.peer_ids:
                if peer_id not in active_spans:
                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                else:  # peer_id in active_spans
                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)

            for peer_id in list(active_spans.keys()):
                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans

        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block

    def __len__(self):
        return len(self.block_uids)
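# Worked example (sketch; peers A and B are hypothetical): with 4 blocks, where A hosts
# blocks 0..2 and B hosts blocks 1..3, compute_spans returns
#   spans_by_priority == [Span(start=0, end=3, peer_id=A), Span(start=1, end=4, peer_id=B)]
# (sorted by span length, longest first), and spans_containing_block maps
# block 0 -> [A's span], blocks 1-2 -> [both spans], block 3 -> [B's span].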
@@ -0,0 +1,134 @@
from __future__ import annotations

import contextlib
import random

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn

from src import DistributedBloomConfig, RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos


use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        self.prefix = prefix
        self.max_retries = max_retries
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
        logger.debug(f"Remote block uids: {block_uids}")
        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)

    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise e
                    else:
                        logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, block_index: int):
        assert 0 <= block_index < self.config.n_layer
        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
        return module

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]

    def __len__(self):
        return len(self.remote_sequence_info)

    def inference_session(self) -> RemoteSequentialInferenceSession:
        self.remote_sequence_info.update_()
        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
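# Example usage (sketch; `config`, `dht` and the "bloom" prefix are assumed to exist and
# to match block uids announced in the DHT by at least one running server):
#
#   sequential = RemoteSequential(config, dht, "bloom")
#   hidden_states = torch.randn(1, 8, config.n_embed)  # [batch, seq_len, hidden]
#   hidden_states = sequential(hidden_states)  # runs all remote blocks, with retries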

class RemoteSequentialInferenceSession:
    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
        self.remote_sequence_info = remote_sequence_info
        self.p2p = p2p
        self.closed = False
        self.stack = contextlib.ExitStack()
        self.active_sessions = []

    def __enter__(self):
        assert not self.closed
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        current_block = 0
        while current_block != len(self.remote_sequence_info):
            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
            chosen_span = random.choice(candidate_spans)  # TODO: this is temporary code
            assert chosen_span.start <= current_block < chosen_span.end

            # TODO begin throwaway prototype code
            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
            _ = remote.info  # TODO fix
            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
            self.active_sessions.append(remote.inference_session())
            self.stack.enter_context(self.active_sessions[-1])
            current_block = chosen_span.end
            # TODO end throwaway prototype code

        return self

    def step(self, inputs: torch.Tensor):
        assert not self.closed
        for session in self.active_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.active_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()
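# Example usage (sketch; `sequential` is a RemoteSequential as above, `embeddings` is a
# tensor of shape [batch, seq_len, hidden]):
#
#   with sequential.inference_session() as session:
#       for step_emb in embeddings.split(1, dim=1):  # feed one position at a time
#           step_out = session.step(step_emb)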
@@ -0,0 +1,57 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os

import torch
import transformers
from hivemind import use_hivemind_log_handler, get_logger

from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()


MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME environment variable with the name of the model to test")

REF_NAME = os.environ.get("REF_NAME")


def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
    parallel_outputs = model.forward(test_inputs).logits
    assert torch.all(torch.isfinite(parallel_outputs))
    logger.info("Forward outputs are finite")

    if REF_NAME:
        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
        # note: this creates a dummy mask to make the test compatible with older transformer versions
        # prior to https://github.com/huggingface/transformers/pull/17837
        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
    else:
        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")

    embs = model.transformer.word_embeddings(test_inputs)
    embs = model.transformer.word_embeddings_layernorm(embs)
    recurrent_outputs = []
    with model.transformer.h.inference_session() as sess:
        for t in range(embs.shape[1]):
            recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
    recurrent_outputs = model.lm_head(recurrent_outputs)
    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
    logger.info("Inference is consistent with forward")