Sequential and parallel forward / backward (#36)

pull/37/head
Dmitry Baranchuk 2 years ago committed by GitHub
parent f0cffbf67e
commit 6573076883

@@ -62,23 +62,10 @@ else
conda activate bloom-demo
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple -r requirements.txt
fi
##############
# Local test #
##############
if [ "$RUN_LOCAL_TESTS" = true ] ; then
echo "Run test on your local machine"
python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
fi
##############
# Run server #
##############

@@ -32,17 +32,15 @@ done
###########################
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
conda activate bloom-demo
else
conda create -y --name bloom-demo python=3.8.12 pip
conda activate bloom-demo
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple -r requirements.txt
fi
@@ -88,7 +86,7 @@ do
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
echo "=== Server #${SERVER_ID} ==="
echo "Server ID: ${id_path}"
echo "Server ID: ${cfg[id_path]}"
echo "Device: ${cfg[device]}"
echo "Bloom block ids: ${cfg[block_ids]}"
echo "Host maddr: ${cfg[maddr]}"

@@ -37,17 +37,15 @@ done
###########################
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
conda activate bloom-demo
else
conda create -y --name bloom-demo python=3.8.12 pip
conda activate bloom-demo
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple -r requirements.txt
fi
@@ -57,7 +55,7 @@ fi
hivemind-dht &> tmp.out &
sleep 3
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"

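For reference, the INITIAL_PEER extraction above can be read as the standalone sketch below; it assumes, as the script does, that the second line of the hivemind-dht log contains the visible multiaddr as its second-to-last whitespace-separated token (illustration only, not part of this commit):

# Standalone equivalent of the inline `python -c` above.
with open("tmp.out") as f:
    lines = f.readlines()

# second log line, second-to-last token -> the address other peers use to join the DHT
initial_peer = lines[1].split()[-2]
print(initial_peer)  # e.g. /ip4/127.0.0.1/tcp/31337/p2p/Qm...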
@@ -1,11 +1,11 @@
# this code is in active development, interfaces may change
import os
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import hivemind
import torch
import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from src.bloom.model import (
BloomConfig,
@@ -66,13 +66,45 @@ class DistributedBloomModel(BloomModel):
for p in self.parameters():
p.requires_grad = value
def forward(self, *args, use_cache=None, **kwargs):
if use_cache:
raise ValueError(
"Distributed forward does not support use_cache; for efficient cache-aware generation, "
"please use model.transformer.inference_session() or model.generate(...)"
)
return super().forward(*args, use_cache=False, **kwargs)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(hidden_states)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
)
class DistributedBloomPrefix(DistributedBloomModel):
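To make the data flow of the new forward() above easier to follow, here is a shape-level sketch with local stand-ins (plain nn modules in place of the real embeddings and of self.h; an illustration of the structure, not the actual classes): the embeddings and both layer norms run on the client, and only the block-sequence call is remote.

import torch
import torch.nn as nn

hidden_size, vocab_size = 64, 1000
word_embeddings = nn.Embedding(vocab_size, hidden_size)       # stand-in for self.word_embeddings
word_embeddings_layernorm = nn.LayerNorm(hidden_size)         # stand-in for self.word_embeddings_layernorm
remote_blocks = nn.Identity()                                 # stand-in for self.h (the remote block sequence)
ln_f = nn.LayerNorm(hidden_size)                              # stand-in for self.ln_f

input_ids = torch.randint(0, vocab_size, (2, 5))              # (batch, seq_len)
hidden_states = word_embeddings_layernorm(word_embeddings(input_ids).float())
hidden_states = remote_blocks(hidden_states)                  # the only hop that goes over the network
hidden_states = ln_f(hidden_states)
print(hidden_states.shape)                                    # torch.Size([2, 5, 64]) == input_shape + (hidden_size,)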
@@ -94,16 +126,10 @@ class DistributedBloomPrefix(DistributedBloomModel):
def forward(
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values=None,
position_ids=None,
head_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
assert (
input_ids is None or inputs_embeds is None
@@ -122,17 +148,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
transformer_outputs = super().forward(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
# Remove prefix
last_hidden_state = transformer_outputs[0][:, self.prefix_length :]

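The prefix model's forward() now only wraps the base forward: it prepends prefix_length learned prompt vectors to the input embeddings and strips those positions from the result. A shape-level sketch of that prepend/strip step, with random tensors standing in for the real prompts and embeddings (illustration only):

import torch

batch_size, prefix_length, seq_len, hidden_size = 2, 16, 10, 64
prompts = torch.randn(batch_size, prefix_length, hidden_size)   # stand-in for self.get_prompt(batch_size)
inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)   # stand-in for the token embeddings

extended = torch.cat([prompts, inputs_embeds], dim=1)           # what super().forward() receives
outputs = extended                                              # placeholder for transformer_outputs[0]
last_hidden_state = outputs[:, prefix_length:]                  # the "Remove prefix" step above
assert last_hidden_state.shape == inputs_embeds.shape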
@@ -12,6 +12,7 @@ import src
from src.client.inference_session import RemoteSequentialInferenceSession
from src.client.remote_block import RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos
@@ -52,21 +53,8 @@ class RemoteSequential(nn.Module):
self.is_subsequence = self.sequence_manager.block_uids != block_uids
def forward(self, inputs: torch.Tensor):
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
for block in iter(self):
for retry_index in range(self.sequence_manager.max_retries):
try:
(outputs,) = block(inputs)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
inputs = outputs
break
except Exception as e:
if retry_index == self.sequence_manager.max_retries - 1:
raise e
else:
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
return inputs
outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
return outputs
def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
assert isinstance(ix, (int, slice))

@@ -0,0 +1,220 @@
import asyncio
import logging
from typing import List, Optional, Sequence, Tuple
import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.expert import expert_backward, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
MAX_TOKENS_IN_BATCH = 1024
async def run_expert_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "expert_forward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without the RemoteExpertWorker.run_coroutine() call, which would cause a deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: figure out whether we should use run_in_executor here
serialized_tensors = (
serialize_torch_tensor(tensor, proto.compression)
for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
)
deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
flat_outputs = tuple(deserialized_outputs)
return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
async def run_expert_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
intermediate_inputs: List[torch.Tensor],
grad_outputs: List[torch.Tensor],
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "expert_backward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without the RemoteExpertWorker.run_coroutine() call, which would cause a deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((intermediate_inputs, grad_outputs_cpu)))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
serialized_tensors = (
serialize_torch_tensor(tensor, proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
)
deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
return deserialized_grad_inputs
async def sequential_forward(
inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
Performs chained forward for each subsequence of blocks on the path.
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
sequences = sequence_manager.make_sequence(start_index, end_index)
intermediate_inputs = []
done_sequences = []
while len(sequences) > 0:
while True:
try:
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
(outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
# Save intermediate inputs and subsequences if the forward is already done for them
intermediate_inputs.append(inputs)
done_sequences.append(span)
inputs = outputs
break
except Exception as e:
logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
backup_sequences = sequence_manager.make_sequence(span.start)
assert backup_sequences[0].start == span.start
sequences = backup_sequences
return outputs, intermediate_inputs, done_sequences
async def sequential_backward(
grad_outputs: Sequence[torch.Tensor],
intermediate_inputs: Sequence[torch.Tensor],
forward_sequences: Sequence[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
) -> Sequence[torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
If some subsequence fails, re-runs the forward for that sub-path to recover the intermediate inputs and resumes the backward from there.
"""
assert len(intermediate_inputs) == len(forward_sequences)
# TODO think about grads w.r.t. deep prompts
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
while True:
try:
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs = await run_expert_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
)
break
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
inputs, sequence_manager, start_index=span.start, end_index=span.end
)
assert len(intermediate_inputs) == len(forward_sequences)
assert backup_forward_sequences[0].start == span.start
assert backup_forward_sequences[-1].end == span.end
forward_sequences.extend(backup_forward_sequences)
intermediate_inputs.extend(backup_intermediate_inputs)
return grad_outputs
async def _gather_forward(input_batches, sequence_manager):
"""Wrapper for asyncio.gather to perform parallel sequential forwards"""
return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
"""Wrapper for asyncio.gather to perform parallel sequential backwards"""
return await asyncio.gather(
*[
sequential_backward((grad_output,), input_batch, spans, sequence_manager)
for grad_output, input_batch, spans in zip(
grad_output_batches, intermediate_input_batches, forward_sequences
)
]
)
class _RemoteSequentialAutogradFunction(torch.autograd.Function):
"""
PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
This function splits the input into batches of at most <MAX_TOKENS_IN_BATCH> tokens (at least one sequence each) and processes them in parallel.
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
sequence_manager.rpc_info # lazy init
outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
assert len(outputs) == len(input_batches)
output_batches = [output[0] for output in outputs]
intermediate_input_batches = [output[1] for output in outputs]
sequences_for_batches = [output[2] for output in outputs]
ctx.sequence_manager = sequence_manager
ctx.intermediate_input_batches = intermediate_input_batches
ctx.sequences_for_batches = sequences_for_batches
return torch.cat(output_batches, dim=0)
@staticmethod
def backward(ctx, grad_outputs: torch.Tensor):
intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intermediate_input_batches
forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
ctx.sequence_manager.rpc_info # lazy init
batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
grad_input_batches = RemoteExpertWorker.run_coroutine(
_gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
)
grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
grad_inputs = torch.cat(grad_inputs, dim=0)
return (grad_inputs, None)
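To illustrate the batching policy shared by forward() and backward() above, here is the splitting arithmetic on its own, with plain tensors and no remote calls (illustration only):

import torch

MAX_TOKENS_IN_BATCH = 1024
inputs = torch.randn(20, 128, 64)                               # (batch=20, seq_len=128, hidden)

batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)     # 1024 // 128 = 8 sequences per sub-batch
input_batches = inputs.split(batch_size)                        # sub-batches of 8, 8 and 4 sequences

assert sum(b.shape[0] for b in input_batches) == inputs.shape[0]
# Each sub-batch is sent through sequential_forward() concurrently via _gather_forward(),
# and torch.cat(output_batches, dim=0) restores the original batch afterwards.

The same arithmetic is applied to grad_outputs in backward(), so the gradient batches line up one-to-one with the intermediate inputs and spans saved during forward().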