Merge branch 'client' into main

pull/18/head
justheuristic
commit d42e8abd38

@ -65,7 +65,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()
# test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
for i in range(10):
res = sess.step(torch.ones(1, 1, 4096))
```
@ -94,7 +94,9 @@ python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigsci
export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
BLOCK_UID=bloom6b3.3 pytest tests/test_block_exact_match.py
BLOCK_UID=bloom6b3.4 pytest tests/test_block_exact_match.py
# the test below will fail because there is no server that serves layer 7
# BLOCK_UID=bloom6b3.7 pytest tests/test_block_exact_match.py
# test full model exact match
MODEL_NAME=bigscience/test-bloomd-6b3 REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py
```
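The BLOCK_UID values above are simply the server prefix joined with a block index. A minimal sketch of how these uids are formed (the "." delimiter is an assumption; see src.data_structures.UID_DELIMITER and the block_uids construction in RemoteSequential further below):

```python
# Sketch: how uids like "bloom6b3.3" are built from a prefix and a block index.
UID_DELIMITER = "."  # assumed value of src.data_structures.UID_DELIMITER
prefix, n_layer = "bloom6b3", 30
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(n_layer))
print(block_uids[3])  # -> "bloom6b3.3", the uid queried by test_block_exact_match.py
```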

@ -14,11 +14,12 @@ def main():
parser = configargparse.ArgParser(default_config_files=["config.yml"])
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
+"use the same name as in the converted model.")
parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,

@ -9,15 +9,8 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (
-BloomGelu,
-BloomScaledSoftmax,
-attention_mask_func,
-build_alibi_tensor,
-dropout_add,
-pre_process_alibi_for_pad,
-split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+pre_process_alibi_for_pad, split_tensor_along_last_dim)
class BloomAttention(nn.Module):

@ -11,11 +11,8 @@ import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-add_code_sample_docstrings,
-add_start_docstrings,
-add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
@ -208,6 +205,8 @@ class BloomModel(BloomPreTrainedModel):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
if position_ids is not None:
logger.warning("position_ids are ignored in this bloom implementation")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
@ -238,9 +237,8 @@ class BloomModel(BloomPreTrainedModel):
# Compute alibi tensor: check build_alibi_tensor documentation
current_sequence_length = hidden_states.shape[1]
-if past_key_values[0] is not None:
+if past_key_values and past_key_values[0]:
current_sequence_length += past_key_values[0][0].shape[1]
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -258,7 +256,7 @@ class BloomModel(BloomPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
-return module(*inputs, use_cache, output_attentions, alibi)
+return module(*inputs, use_cache, output_attentions, alibi=None)
return custom_forward
@ -277,7 +275,7 @@ class BloomModel(BloomPreTrainedModel):
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
-alibi=alibi,
+alibi=None,
)
hidden_states = outputs[0]
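With this change the blocks receive alibi=None; presumably each block rebuilds the ALiBi bias internally from the current sequence length (build_alibi_tensor is still imported in src/bloom/ops above). As a reminder of what that bias looks like, here is an illustrative standalone re-implementation, not the repo's build_alibi_tensor; it assumes the number of heads is a power of two:

```python
import torch

def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Per-head slopes from the ALiBi paper: head i gets 2 ** (-8 * (i + 1) / n_heads).
    return torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])

def toy_alibi(seq_len: int, n_heads: int) -> torch.Tensor:
    # Bias of shape [n_heads, 1, seq_len]: each head penalizes distant positions linearly.
    positions = torch.arange(seq_len, dtype=torch.float32)
    return alibi_slopes(n_heads)[:, None, None] * positions[None, None, :]

print(toy_alibi(seq_len=4, n_heads=2))
```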

@ -11,13 +11,17 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, StubBase
from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten
+from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
from src.data_structures import RemoteModuleInfo
from src.dht_utils import ModuleUID
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteTransformerBlock(RemoteExpert):
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
@ -29,11 +33,20 @@ class RemoteTransformerBlock(RemoteExpert):
def stub(self) -> StubBase:
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+def forward(self, inputs: torch.Tensor, **kwargs):
+for k, v in kwargs.items():
+assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
+return super().forward(inputs)
+def inference_session(self) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
def begin_inference_session(self):
logger.warning("beging_inference_session was renamed to just inference_session")
return self.inference_session()
class RemoteTransformerBlockInferenceSession:
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
@ -44,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.stepped = False
self.closed = False
@classmethod
@ -89,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await anext(self._outputs_stream)
def close(self):
@ -103,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
if self._outputs_stream is None:
return # already closed
-await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
-try:
-await anext(self._outputs_stream)
-except StopAsyncIteration:
-pass
+if self.stepped:
+await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
+try:
+await anext(self._outputs_stream)
+except StopAsyncIteration:
+pass
def __del__(self):
self.close()
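A minimal usage sketch of the renamed API, mirroring the README snippet earlier in this commit. It assumes `block` is a RemoteTransformerBlock obtained elsewhere (for example via the helpers in src.dht_utils) and that the block's hidden size is 4096:

```python
import torch

def run_incremental_inference(block, num_steps: int = 10, hidden_size: int = 4096) -> torch.Tensor:
    """Feed one dummy position at a time through a single remote block."""
    outputs = []
    # inference_session() replaces begin_inference_session(); the old name now only
    # logs a warning and delegates to the new one.
    with block.inference_session() as sess:
        for _ in range(num_steps):
            outputs.append(sess.step(torch.ones(1, 1, hidden_size)))
    return torch.cat(outputs, dim=1)
```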

@ -0,0 +1,49 @@
# this code is in active development, interfaces may change
import os
from typing import Optional, Union
import hivemind
from hivemind import DHT, get_logger, use_hivemind_log_handler
from src.bloom import BloomForCausalLM, DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class DistributedBloomForCausalLM(BloomForCausalLM):
"""BloomForCausalLM, but all transformer layers are hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
super().__init__(config)
assert len(self.transformer.h) == 0
config.n_layer = n_layer
self.transformer.h = RemoteSequential(config, dht, prefix)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
if 'initial_peers' not in kwargs:
raise ValueError("Please specify initial_peers=...")
dht = hivemind.DHT(
initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
start=True)
if 'prefix' not in kwargs:
logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
assert UID_DELIMITER not in pretrained_model_name_or_path, \
f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
model = cls(config, dht, prefix)
model.load_state_dict(_load_state_dict(
pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
), strict=True)
return model
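A minimal client-side sketch of the new class, assuming a running swarm. The multiaddr below is a placeholder to be copied from the server output; the model name matches the one used in the tests:

```python
import transformers
from src.client.remote_model import DistributedBloomForCausalLM

initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/PLACEHOLDER_PEER_ID"]  # copy from server output
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=initial_peers)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
logits = model.forward(inputs).logits  # embeddings run locally, transformer blocks run on remote servers
print(logits.shape)
```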

@ -0,0 +1,93 @@
from __future__ import annotations
import dataclasses
import threading
from functools import partial
from typing import Tuple, List, Optional, Sequence, NamedTuple
from hivemind import DHT, PeerID
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
from src.data_structures import ModuleUID, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
@dataclasses.dataclass(frozen=False, init=False)
class RemoteSequenceInfo:
"""Keeps and updates the meta-information about which peers host which blocks"""
dht: DHT
block_uids: List[ModuleUID]
block_infos: List[Optional[RemoteModuleInfo]]
spans_by_priority: List[Span] # sorted from best to worst
spans_containing_block: Tuple[List[Span], ...]
lock_changes: threading.Lock
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
self.dht = dht
self.block_uids = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
self.spans_by_priority = []
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
self.lock_changes = threading.Lock()
self.update_()
for uid, info in zip(self.block_uids, self.block_infos):
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block
def update_(self):
with self.lock_changes:
self.update_block_infos_()
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
def update_block_infos_(self):
new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
return_future=False)
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.warning(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
if not info.peer_ids:
logger.warning(f"Found no active peers for block {uid}")
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
if not isinstance(info.peer_ids, set):
logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
self.block_infos[block_index] = info
@staticmethod
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
for peer_id in info.peer_ids:
if peer_id not in active_spans:
active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
for peer_id in list(active_spans.keys()):
if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return closed_spans, spans_containing_block
def __len__(self):
return len(self.block_uids)
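A standalone toy of the span-merging idea in compute_spans above: given, for each block, the set of peers serving it, contiguous runs of blocks served by the same peer become spans, sorted longest-first. This sketch is illustrative and uses plain strings instead of PeerID:

```python
from typing import List, NamedTuple, Sequence, Set

class Span(NamedTuple):
    start: int
    end: int
    peer_id: str

def toy_compute_spans(peers_per_block: Sequence[Set[str]]) -> List[Span]:
    closed, active = [], {}
    for block_index, peers in enumerate(peers_per_block):
        for peer in peers:
            if peer not in active:
                active[peer] = Span(block_index, block_index + 1, peer)
            else:
                active[peer] = active[peer]._replace(end=block_index + 1)
        for peer in list(active):
            if peer not in peers or block_index == len(peers_per_block) - 1:
                closed.append(active.pop(peer))
    closed.sort(key=lambda span: span.end - span.start, reverse=True)
    return closed

# Peer "A" serves blocks 0..2, peer "B" serves blocks 1..3:
print(toy_compute_spans([{"A"}, {"A", "B"}, {"A", "B"}, {"B"}]))
# -> [Span(start=0, end=3, peer_id='A'), Span(start=1, end=4, peer_id='B')]
```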

@ -0,0 +1,134 @@
from __future__ import annotations
import contextlib
import logging
import random
import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn
from src import DistributedBloomConfig, RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteSequential(nn.Module):
"""
A sequence of transformer blocks hosted by the swarm.
"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(
f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
f"This will cause {self.__class__.__name__} to look for modules under "
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
)
super().__init__()
self.config = config
self.dht = dht
self.prefix = prefix
self.max_retries = max_retries
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
logger.debug(f"Remote block uids: {block_uids}")
self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
def forward(self, inputs: torch.Tensor):
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
for block_index in range(self.config.n_layer):
for retry_index in range(self.max_retries):
try:
block = self[block_index]
(outputs,) = block(inputs)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
inputs = outputs
break
except Exception as e:
if retry_index == self.max_retries - 1:
raise e
else:
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
return inputs
def __getitem__(self, block_index: int):
assert 0 <= block_index < self.config.n_layer
(module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
return module
def __iter__(self):
for block_index in range(self.config.n_layer):
yield self[block_index]
def __len__(self):
return len(self.remote_sequence_info)
def inference_session(self) -> RemoteSequentialInferenceSession:
self.remote_sequence_info.update_()
return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
class RemoteSequentialInferenceSession:
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
self.remote_sequence_info = remote_sequence_info
self.p2p = p2p
self.closed = False
self.stack = contextlib.ExitStack()
self.active_sessions = []
def __enter__(self):
assert not self.closed
self.stack.__enter__()
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
current_block = 0
while current_block != len(self.remote_sequence_info):
candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
assert chosen_span.start <= current_block < chosen_span.end
# TODO begin throwaway prototype code
remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
_ = remote.info  # TODO fix
span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
self.active_sessions.append(remote.inference_session())
self.stack.enter_context(self.active_sessions[-1])
current_block = chosen_span.end
# TODO end throwaway prototype code
return self
def step(self, inputs: torch.Tensor):
assert not self.closed
for session in self.active_sessions:
outputs = session.step(inputs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self.closed:
self.stack.__exit__(*exc_details or (None, None, None))
self.active_sessions.clear()
self.closed = True
def __exit__(self, *exc_details):
self.close(*exc_details)
def __del__(self):
self.close()
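A sketch of how RemoteSequential and its inference session fit together, assuming a swarm that serves all blocks under the prefix used in the README. The multiaddr is a placeholder, and the client code above normally loads the config with revision=CLIENT_BRANCH:

```python
import hivemind
import torch
from src import DistributedBloomConfig
from src.client.remote_sequential import RemoteSequential

dht = hivemind.DHT(
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/PLACEHOLDER_PEER_ID"], client_mode=True, start=True
)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
blocks = RemoteSequential(config, dht, prefix="bloom6b3")

hidden = torch.randn(1, 1, 4096)  # 4096 = hidden size of the 6b3 model, as in the README example
hidden = blocks(hidden)  # non-incremental forward: each block is called once, with retries

with blocks.inference_session() as sess:  # incremental inference with server-side attention caches
    hidden = sess.step(torch.randn(1, 1, 4096))
```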

@ -106,7 +106,8 @@ async def _get_remote_module_infos(
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
logger.error(f"Incorrect metadata for {uid}: {metadata}")
if metadata is not None:
logger.error(f"Incorrect metadata for {uid}: {metadata}")
continue
valid_entries = set()
for maybe_peer_id, _unused_value in metadata.value.items():

@ -26,29 +26,29 @@ class TransformerBackend(ModuleBackend):
self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-attention_cache_handle = int(cache_metadata[0, 0].item())
-prefix_length = int(cache_metadata[0, 1].item())
-hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
-assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-with self.memory_cache.use_cache(attention_cache_handle) as cache:
-print("METADATA:", cache_metadata)
-assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-print("PAST", past_k.shape, past_v.shape)
-hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-# todo remove these asserts once we pass all tests
-new_length = new_v.shape[1]
-assert new_length > prefix_length
-assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
-assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
-assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
-assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
-cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
-return (hidden_states,)
+with torch.inference_mode():
+attention_cache_handle = int(cache_metadata[0, 0].item())
+prefix_length = int(cache_metadata[0, 1].item())
+hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
+assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+with self.memory_cache.use_cache(attention_cache_handle) as cache:
+assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+# todo remove these asserts once we pass all tests
+new_length = new_v.shape[1]
+assert new_length > prefix_length
+assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
+assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
+cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+return (hidden_states,)
def get_pools(self) -> Sequence[TaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
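To make the cache bookkeeping above easier to follow, here is a standalone toy with illustrative shapes (the real cache layout may order dimensions differently): row 0 of the cache holds keys, row 1 holds values, and each inference step writes only the freshly computed positions past prefix_length:

```python
import torch

batch, max_len, heads, head_dim = 1, 16, 8, 4
cache = torch.zeros(2, batch, max_len, heads, head_dim)  # [key/value, batch, seq, heads, head_dim]

prefix_length, new_length = 3, 4  # three positions already cached, one new position this step
new_k = torch.randn(batch, new_length, heads, head_dim)
new_v = torch.randn(batch, new_length, heads, head_dim)

# keep the cached prefix, append only the freshly computed positions
cache[0, :, prefix_length:new_length] = new_k[:, prefix_length:new_length]
cache[1, :, prefix_length:new_length] = new_v[:, prefix_length:new_length]

past_k, past_v = cache[0, :, :new_length], cache[1, :, :new_length]
assert past_k.shape == past_v.shape == (batch, new_length, heads, head_dim)
```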

@ -14,6 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
@ -84,7 +85,7 @@ class Server(threading.Thread):
@classmethod
def create(
cls,
-prefix: str,
+prefix: Optional[str],
converted_model_name_or_path: str,
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
@ -108,6 +109,12 @@ class Server(threading.Thread):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)
if prefix is None:
prefix = converted_model_name_or_path
assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
f"Please specify --prefix manually when starting a server"
logger.info(f"Automatic dht prefix: {prefix}")
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]

@ -32,7 +32,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
(outputs_forward,) = remote_block(inputs)
outputs_inference = []
-with remote_block.begin_inference_session() as sess:
+with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)

@ -39,7 +39,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
inputs = torch.randn(1, 8, 4096)
outputs_inference = []
-with remote_block.begin_inference_session() as sess:
+with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)

@ -0,0 +1,57 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os
import torch
import transformers
from hivemind import use_hivemind_log_handler, get_logger
from src.client.remote_model import DistributedBloomForCausalLM
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")
def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
assert len(model.transformer.h) == model.config.n_layer
test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
parallel_outputs = model.forward(test_inputs).logits
assert torch.all(torch.isfinite(parallel_outputs))
logger.info("Forward outputs are finite")
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
embs = model.transformer.word_embeddings(test_inputs)
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
with model.transformer.h.inference_session() as sess:
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")