black-isort-clarify

pull/467/head
Your Name 9 months ago
parent 9e29140bb0
commit 17d278e88a

@ -104,9 +104,13 @@ async def run_remote_forward(
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
return await forward_fn(
output_tensors = await forward_fn(
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
)
# backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]
return output_tensors
async def run_remote_backward(

@ -493,6 +493,7 @@ class RemoteSequenceManager:
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
"""
return a sequence of compression codecs for client-side compression (applied to tensors sent to remote server)
:param peer_id: remote server's PeerID
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args: request-specific input tensors

@ -31,20 +31,27 @@ logger = get_logger(__name__)
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
args_structure: Any,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
:note: see pack_args_kwargs function for the definition of args_structure
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:param active_adapter: the name of LoRA adapter to use; defaults to no adapter
:param prioritizer: assigns priorities to each sub-request based on the number of points
:param points: client-specified number of points, used to assign priorities
:param args_structure:
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
(hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
@ -77,7 +84,7 @@ async def run_rpc_forward(
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
return hidden_states.requires_grad_(requires_grad)
async def run_rpc_backward(
@ -88,19 +95,22 @@ async def run_rpc_backward(
points: int = 0,
args_structure: Any,
) -> Tuple[Sequence[torch.Tensor], Any]:
"""A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
requested_backends, flat_tensors, args_structure
)
input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
# Cast inputs & grad outputs to backend dtype
num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
hidden_states = hidden_states.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
@ -140,7 +150,7 @@ async def run_rpc_backward(
active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
if not is_dummy(prompt) and prompts_requires_grad:
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_backend_kwargs_reversed.append(grad_kwargs)

@ -361,18 +361,19 @@ class TransformerConnectionHandler(ConnectionHandler):
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await run_rpc_forward(
*flat_inputs,
args_structure=args_structure,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
)
@ -396,11 +397,11 @@ class TransformerConnectionHandler(ConnectionHandler):
hidden_states = await run_rpc_forward(
*flat_inputs,
args_structure=args_structure,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -450,8 +451,8 @@ class TransformerConnectionHandler(ConnectionHandler):
flat_grads, grads_structure = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
@ -479,8 +480,8 @@ class TransformerConnectionHandler(ConnectionHandler):
flat_grads, grad_structure = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)

@ -73,8 +73,8 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs)
return rpc_info
def get_request_metadata(self, protocol: str, *args, **kwargs):
metadata = super().get_request_metadata(protocol, *args, **kwargs)
def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs):
metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs)
if protocol == "rpc_forward":
metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
elif protocol == "rpc_backward":

Loading…
Cancel
Save