black-isort

justheuristic 2 years ago
parent 0f9cd687d4
commit 4ad845bce3
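This commit applies automated formatting only: isort regroups and rewraps imports, and black normalizes quotes, wrapping, trailing commas, and whitespace; the diff below appears to contain no logic changes. For illustration only (not part of the commit), here is a minimal sketch of how the two tools produce rewrites like the src.bloom.ops import change further down. It assumes isort's "black" profile and a line length of roughly 120, and uses the tools' standard Python entry points (isort.code, black.format_str); the repository's actual configuration is not shown in this diff.

# Illustration only -- a hypothetical sketch, not the repository's tooling or config.
import black
import isort

src = (
    "from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,\n"
    "                           pre_process_alibi_for_pad, split_tensor_along_last_dim)\n"
)

# isort (black-compatible profile) sorts imports and wraps long ones into a parenthesized, one-name-per-line form
rewritten = isort.code(src, profile="black", line_length=120)
# black then enforces quote style, trailing commas, and line wrapping on top of that
rewritten = black.format_str(rewritten, mode=black.Mode(line_length=120))
print(rewritten)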

@@ -48,7 +48,7 @@ if __name__ == "__main__":
     config = transformers.AutoConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    model = transformers.AutoModel.from_pretrained(
+    model = transformers.AutoModel.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
     tokenizer = transformers.AutoTokenizer.from_pretrained(

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)


 class BloomAttention(nn.Module):

@@ -9,8 +9,11 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
@@ -153,9 +156,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head

         # Embedding + LN Embedding
         # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         # Transformer blocks
@@ -177,10 +180,10 @@ class BloomModel(BloomPreTrainedModel):
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings

     def set_requires_grad(self, value):
         for p in self.parameters():
-            p.requires_grad=value
+            p.requires_grad = value

     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -320,9 +323,9 @@ class BloomForYou(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = None
+        super().__init__(config)
+        self.transformer = BloomModel(config)
+        self.lm_head = None

-        # Initialize weights and apply final processing
-        self.post_init()
+        # Initialize weights and apply final processing
+        self.post_init()

@@ -31,29 +31,31 @@ class DistributedBloomForYou(BloomForYou):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        if 'initial_peers' not in kwargs:
+        if "initial_peers" not in kwargs:
             raise ValueError("Please specify initial_peers=...")
         dht = hivemind.DHT(
-            initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
-            start=True)
+            initial_peers=kwargs.pop("initial_peers"), client_mode=kwargs.pop("client_mode", True), start=True
+        )

-        if 'prefix' not in kwargs:
+        if "prefix" not in kwargs:
             logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
-            assert UID_DELIMITER not in pretrained_model_name_or_path, \
-                f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
+            assert (
+                UID_DELIMITER not in pretrained_model_name_or_path
+            ), f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
         prefix = kwargs.pop("prefix", pretrained_model_name_or_path)

         config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
         model = cls(config, dht, prefix)
-        model.transformer.load_state_dict(_load_state_dict(
-            pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
-        ), strict=True)
+        model.transformer.load_state_dict(
+            _load_state_dict(pretrained_model_name_or_path, use_auth_token=kwargs.get("use_auth_token")), strict=True
+        )
         return model


 class DistributedBloomForCausalLM(DistributedBloomForYou):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
@@ -86,9 +88,7 @@ class DistributedBloomForCausalLM(DistributedBloomForYou):
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(
-            input_ids=input_ids, return_dict=return_dict, **kwargs
-        )
+        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)

         # Switch dtype in case word_embeddings are fp16
         word_embeddings = self.transformer.word_embeddings.weight.t()

@@ -15,12 +15,13 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

-Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])


 @dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] eto ne dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""

     dht: DHT
     block_uids: List[ModuleUID, ...]
     block_infos: List[Optional[RemoteModuleInfo], ...]
@@ -48,8 +49,8 @@ class RemoteSequenceInfo:
     def update_block_infos_(self):
         new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
-            return_future=False)
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
+        )
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:

@@ -103,8 +103,8 @@ class RemoteSequentialInferenceSession:
             # TODO begin throwaway prototype code
             remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            _=remote.info #TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
+            _ = remote.info  # TODO fix
+            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
             remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
             self.active_sessions.append(remote.inference_session())
             self.stack.enter_context(self.active_sessions[-1])

@@ -30,13 +30,17 @@ class TransformerBackend(ModuleBackend):
         attention_cache_handle = int(cache_metadata[0, 0].item())
         prefix_length = int(cache_metadata[0, 1].item())
         hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
-        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        assert (
+            hidden_states.ndim == 3
+        ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

         with self.memory_cache.use_cache(attention_cache_handle) as cache:
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
             print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-            hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            hidden_states, (new_k, new_v) = self.module.forward(
+                hidden_states, layer_past=layer_past, use_cache=True
+            )

             # todo remove these asserts once we pass all tests
             new_length = new_v.shape[1]

@@ -76,22 +76,22 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Run a chain of requested backends
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)

         # Serialize the overall output and respond
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+        )

     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -101,48 +101,41 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Run a chain of requested backends
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)

         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
+            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]

         # Split the serialized_output for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_output
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])

-    async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
         inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
+        # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-                    ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)
@@ -150,16 +143,16 @@ class TransformerConnectionHandler(ConnectionHandler):
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]

         # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            ]
+        )

     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -170,35 +163,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
+        # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-                    ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)

-        # Run a backward chain for requested backends
+        # Run a backward chain for requested backends
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]

         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
+            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_grad_inputs
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]

         async for part in as_aiter(*output_split):

@@ -111,9 +111,10 @@ class Server(threading.Thread):
             add_custom_models_from_file(custom_module_path)
         if prefix is None:
             prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                 f"Please specify --prefix manually when starting a server"
+            )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
@@ -139,7 +140,9 @@ class Server(threading.Thread):
             assert num_blocks is not None
             block_indices = range(num_blocks)  # TODO replace with proper load balancing

-        block_config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        block_config = DistributedBloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )

         # initialize modules
         blocks = {}

@@ -30,10 +30,10 @@ REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
 # seq_length <= 128: rpc_forward & rpc_backward
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block, = get_remote_module(dht, BLOCK_UID)
+    (remote_block,) = get_remote_module(dht, BLOCK_UID)
     assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
     assert isinstance(remote_block, RemoteTransformerBlock)

     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
@@ -41,7 +41,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
         load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
         load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
-    ]
+    ]
     inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
     outputs_rpc = remote_block.forward(inputs)[0]
     outputs_rpc.sum().backward()

@@ -29,7 +29,7 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
     assert len(model.transformer.h) == model.config.n_layer

-    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
+    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     parallel_outputs = model.forward(test_inputs).logits
     assert torch.all(torch.isfinite(parallel_outputs))
     logger.info("Forward outputs are finite")
@@ -49,7 +49,7 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
     recurrent_outputs = []
     with model.transformer.h.inference_session() as sess:
         for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
+            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
