Deep distributed prompt tuning (#42)

* implemented an option to add learnable prompts to intermediate layers
* added support for prompts (as input) in rpc_forward and rpc_backward
* added a test to check that RemoteSequential works correctly with deep prompts

Co-authored-by: justheuristic <justheuristic@gmail.com>
fix-convert-8bit
Dmitry Baranchuk 2 years ago committed by GitHub
parent 9460220a10
commit 6095f58681
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
# this code is in active development, interfaces may change
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import hivemind
import torch
@ -17,6 +17,7 @@ from src.bloom.model import (
)
from src.client.remote_generation import RemoteGenerationMixin
from src.client.remote_sequential import RemoteSequential
from src.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -33,6 +34,7 @@ class DistributedBloomConfig(BloomConfig):
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
class DistributedBloomModel(BloomModel):
@ -60,10 +62,41 @@ class DistributedBloomModel(BloomModel):
# Forbid accumulate grads for embeddings and layernorm
self.set_requires_grad(False)
if config.tuning_mode and "ptune" in config.tuning_mode:
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
self.pre_seq_len = config.pre_seq_len
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
if config.tuning_mode == "deep_ptune":
self.intermediate_prompt_embeddings = nn.Embedding(
self.pre_seq_len,
config.num_hidden_layers * config.hidden_size
# ^-- TODO: should be num_hidden_layers - 1
)
self.intermediate_prompt_embeddings.weight.data.zero_()
elif config.tuning_mode:
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
def set_requires_grad(self, value):
for p in self.parameters():
p.requires_grad = value
def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
prompts = self.prompt_embeddings(prefix_tokens)
if self.config.tuning_mode == "deep_ptune":
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
intermediate_prompts = intermediate_prompts.view(
batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
)
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
else:
intermediate_prompts = DUMMY
return prompts, intermediate_prompts
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -90,10 +123,22 @@ class DistributedBloomModel(BloomModel):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# Note: it supports only float32 or bfloat16 inputs
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(hidden_states)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.h(hidden_states)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = hidden_states[:, self.pre_seq_len :]
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
@ -106,55 +151,6 @@ class DistributedBloomModel(BloomModel):
)
class DistributedBloomPrefix(DistributedBloomModel):
"""DistributedBloomModel with prefix tokens for prompt tuning"""
def __init__(self, config):
super().__init__(config)
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
self.pre_seq_len = config.pre_seq_len
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
prompts = self.prompt_embeddings(prefix_tokens)
return prompts
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
assert (
input_ids is None or inputs_embeds is None
), "You cannot specify both input_ids and inputs_embeds at the same time"
assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
batch_size = inputs_embeds.shape[0]
if attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
# Remove prefix
last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
transformer_outputs["last_hidden_state"] = last_hidden_state
return transformer_outputs
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
@ -162,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
def __init__(self, config: DistributedBloomConfig):
BloomPreTrainedModel.__init__(self, config)
if config.pre_seq_len > 0:
self.transformer = DistributedBloomPrefix(config)
else:
self.transformer = DistributedBloomModel(config)
self.transformer = DistributedBloomModel(config)
self.lm_head = LMHead(config, self.transformer.word_embeddings)
# Initialize weights and apply final processing
@ -195,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
def __init__(self, config: DistributedBloomConfig):
super().__init__(config)
if config.pre_seq_len > 0:
self.transformer = DistributedBloomPrefix(config)
else:
self.transformer = DistributedBloomModel(config)
self.transformer = DistributedBloomModel(config)
# Initialize weights and apply final processing
self.post_init()

@ -15,6 +15,7 @@ from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos
from src.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -52,8 +53,8 @@ class RemoteSequential(nn.Module):
assert isinstance(sequence_manager.block_uids, list)
self.is_subsequence = self.sequence_manager.block_uids != block_uids
def forward(self, inputs: torch.Tensor):
outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:

@ -12,6 +12,7 @@ from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
MAX_TOKENS_IN_BATCH = 1024
@ -33,7 +34,13 @@ async def run_expert_forward(
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
@ -44,7 +51,7 @@ async def run_expert_forward(
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)
@ -57,8 +64,9 @@ async def run_expert_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
intemediate_inputs: List[torch.Tensor],
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "expert_backward".
@ -67,8 +75,14 @@ async def run_expert_backward(
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
# Asynchronous serialization
loop = asyncio.get_running_loop()
@ -84,7 +98,11 @@ async def run_expert_backward(
async def sequential_forward(
inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
inputs: torch.Tensor,
prompts: torch.Tensor,
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
@ -96,6 +114,9 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids
) # should be n_layers - 1 but add extra prompts for convenience
sequences = sequence_manager.make_sequence(start_index, end_index)
intermediate_inputs = []
@ -107,7 +128,9 @@ async def sequential_forward(
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
(outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
(outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@ -119,7 +142,7 @@ async def sequential_forward(
inputs = outputs
break
except Exception as e:
logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
backup_sequences = sequence_manager.make_sequence(span.start)
assert backup_sequences[0].start == span.start
sequences = backup_sequences
@ -129,58 +152,68 @@ async def sequential_forward(
async def sequential_backward(
grad_outputs: Sequence[torch.Tensor],
intermediate_inputs: Sequence[torch.Tensor],
forward_sequences: Sequence[RemoteSpanInfo],
intermediate_inputs: List[torch.Tensor],
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
) -> Sequence[torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
"""
assert len(intermediate_inputs) == len(forward_sequences)
# TODO think about grads w.r.t. deep prompts
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
while True:
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs = await run_expert_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
grad_outputs, *span_grad_prompts = await run_expert_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
break
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
inputs, sequence_manager, start_index=span.start, end_index=span.end
inputs, prompts[span.start : span.end], sequence_manager, start_index=span.start, end_index=span.end
)
assert len(intermediate_inputs) == len(forward_sequences)
assert backup_forward_sequences[0].start == span.start
assert backup_forward_sequences[-1].end == span.end
forward_sequences.extend(backup_forward_sequences)
intermediate_inputs.extend(backup_intermediate_inputs)
return grad_outputs
# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
return grad_outputs, grad_prompts
async def _gather_forward(input_batches, sequence_manager):
async def _gather_forward(input_batches, prompt_batches, sequence_manager):
"""Wrapper for asyncio.gather to perform parallel sequential forwards"""
return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
return await asyncio.gather(
*[
sequential_forward(input_batch, prompt_batch, sequence_manager)
for input_batch, prompt_batch in zip(input_batches, prompt_batches)
]
)
async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
async def _gather_backward(
grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
"""Wrapper for asyncio.gather to perform parallel sequential backwards"""
return await asyncio.gather(
*[
sequential_backward((grad_output,), input_batch, spans, sequence_manager)
for grad_output, input_batch, spans in zip(
grad_output_batches, intermediate_input_batches, forward_sequences
sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
for grad_output, input_batch, prompt_batch, spans in zip(
grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
)
]
)
@ -193,18 +226,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
if is_dummy(prompts):
prompt_batches = [DUMMY] * len(input_batches)
else:
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
sequence_manager.rpc_info # lazy init
outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
assert len(outputs) == len(input_batches)
output_batches = [output[0] for output in outputs]
intemediate_input_batches = [output[1] for output in outputs]
sequences_for_batches = [output[2] for output in outputs]
ctx.prompt_batches = prompt_batches
ctx.sequence_manager = sequence_manager
ctx.intemediate_input_batches = intemediate_input_batches
ctx.sequences_for_batches = sequences_for_batches
@ -220,9 +258,19 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
grad_input_batches = RemoteExpertWorker.run_coroutine(
_gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
outputs = RemoteExpertWorker.run_coroutine(
_gather_backward(
grad_output_batches,
intermediate_input_batches,
ctx.prompt_batches,
forward_sequences,
ctx.sequence_manager,
)
)
grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
grad_inputs = torch.cat(grad_inputs, dim=0)
return (grad_inputs, None)
grad_input_batches = [output[0][0] for output in outputs]
grad_prompt_batches = [output[1] for output in outputs]
grad_inputs = torch.cat(grad_input_batches, dim=0)
dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
return (grad_inputs, grad_prompts, None)

@ -1,8 +1,16 @@
import contextlib
from typing import AsyncIterator, Dict, Sequence
from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind import (
DHT,
MSGPackSerializer,
P2PContext,
TensorDescriptor,
deserialize_torch_tensor,
nested_flatten,
serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
@ -12,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
from src.utils.misc import DUMMY, is_dummy
class TransformerConnectionHandler(ConnectionHandler):
@ -33,7 +42,7 @@ class TransformerConnectionHandler(ConnectionHandler):
try:
print("OPENED RPC_INFERENCE")
request = await anext(requests)
requested_uids = self._check_header(request)
requested_uids = self._check_uids(request.uid)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
batch_size = request.tensors[0].size[0] if request.tensors else 1
@ -80,27 +89,18 @@ class TransformerConnectionHandler(ConnectionHandler):
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
# Parse request and prepare backends
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_header(request)
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
# Run a chain of requested backends
for backend in requested_backends:
assert isinstance(hidden_states, (list, tuple))
assert (
len(hidden_states) == 1 and hidden_states[0].ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
hidden_states = await backend.forward_pool.submit_task(*hidden_states)
# Serialize the overall output and respond
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
# Serialize output and respond to client
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
)
@ -108,29 +108,20 @@ class TransformerConnectionHandler(ConnectionHandler):
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
# Parse requests and prepare backends
uids_header, hidden_states = await self._gather_inputs(requests, context)
requested_uids = self._check_header_str(uids_header)
uid_str, flat_inputs = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uid_str)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
# Run a chain of requested backends
for backend in requested_backends:
assert isinstance(hidden_states, (list, tuple))
assert (
len(hidden_states) == 1 and hidden_states[0].ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
hidden_states = await backend.forward_pool.submit_task(*hidden_states)
hidden_states = await _rpc_forward(flat_inputs, requested_backends)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
# Serialize the overall output
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
serialized_output = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
# Split the serialized_output for streaming and respond
# Split the serialized_output for streaming and respond to client
output_split = [
part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
@ -139,36 +130,25 @@ class TransformerConnectionHandler(ConnectionHandler):
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
# Parse requests and prepare backends
inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_header(request)
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grads = grads.to(requested_backends[-1].dtype)
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = [inputs]
for backend in requested_backends[:-1]:
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
inputs = await backend.forward_pool.submit_task(inputs)
assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
inputs = inputs[0]
inter_inputs.append(inputs)
# Run a chain of requested backends
for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
inputs_and_grads = [inp, grads]
grads = await backend.backward_pool.submit_task(*inputs_and_grads)
assert isinstance(grads, (list, tuple)) and len(grads) == 1
grads = grads[0]
grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_input and respond
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
)
@ -176,36 +156,23 @@ class TransformerConnectionHandler(ConnectionHandler):
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
inputs, grads = inputs_and_grads
requested_uids = self._check_header_str(uids_header)
uids_header, flat_tensors = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uids_header)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grads = grads.to(requested_backends[-1].dtype)
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its outputs
inter_inputs = [inputs]
for backend in requested_backends[:-1]:
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
inputs = await backend.forward_pool.submit_task(inputs)
assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
inputs = inputs[0]
inter_inputs.append(inputs)
# Run a backward chain for requested backends
for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
inputs_and_grads = [inp, grads]
grads = await backend.backward_pool.submit_task(*inputs_and_grads)
assert isinstance(grads, (list, tuple)) and len(grads) == 1
grads = grads[0]
grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_inputs
serialized_grad_inputs = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
# Split the serialized_grad_inputs for streaming and respond
output_split = [
@ -215,19 +182,9 @@ class TransformerConnectionHandler(ConnectionHandler):
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
"""Check that the first request to rpc_inference is valid"""
uids = (request.uid or "").split(CHAIN_DELIMITER)
if not uids:
raise RuntimeError("User did not provide any uids")
for uid in uids:
if uid not in self.module_backends:
raise RuntimeError(f"Remote peer does not serve {uid}")
return tuple(uids)
def _check_header_str(self, header) -> Sequence[ModuleUID]:
"""Check that the first request to rpc_inference is valid"""
uids = (header or "").split(CHAIN_DELIMITER)
uids = (uids or "").split(CHAIN_DELIMITER)
if not uids:
raise RuntimeError("User did not provide any uids")
for uid in uids:
@ -252,3 +209,83 @@ class TransformerConnectionHandler(ConnectionHandler):
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
yield handles
async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, *prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if not prompts or is_dummy(prompts[0]):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, :pre_seq_len] += prompt
(hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
# Serialize the overall output
return hidden_states
async def _rpc_backward(
*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, *prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if not prompts or is_dummy(prompts[0]):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, :pre_seq_len] += prompt
inter_inputs.append(inputs)
(inputs,) = await backend.forward_pool.submit_task(inputs)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, :pre_seq_len] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape

@ -0,0 +1,7 @@
import torch
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
def is_dummy(tensor: torch.Tensor):
return tensor.numel() == 0

@ -4,6 +4,7 @@ from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *
from src import RemoteSequential
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_model import DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
@ -41,3 +42,48 @@ def test_remote_sequential():
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad)
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
remote_sequential = RemoteSequential(config, dht)
inputs = torch.randn(batch_size, seq_len, config.hidden_size)
output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
input_prompts = input_prompts.detach().requires_grad_(True)
intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
(outputs * output_proj).sum().backward()
assert intermediate_prompts.grad is not None
input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
assert input_prompts_ref.grad is None
assert intermediate_prompts_ref.grad is None
outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
for block_index in range(config.n_layer):
block_prompt = intermediate_prompts_ref[block_index]
outputs_ref[:, : block_prompt.shape[1]] += block_prompt
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
(outputs_ref,) = block(outputs_ref)
assert torch.allclose(outputs_ref, outputs)
(outputs_ref * output_proj).sum().backward()
assert input_prompts_ref.grad is not None
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
assert intermediate_prompts_ref.grad is not None
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)

Loading…
Cancel
Save