mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
black-isort-clarify
This commit is contained in:
parent
9e29140bb0
commit
17d278e88a
@ -104,9 +104,13 @@ async def run_remote_forward(
|
||||
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
|
||||
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
|
||||
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
|
||||
return await forward_fn(
|
||||
output_tensors = await forward_fn(
|
||||
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
|
||||
)
|
||||
# backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591
|
||||
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
|
||||
output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]
|
||||
return output_tensors
|
||||
|
||||
|
||||
async def run_remote_backward(
|
||||
|
@ -493,6 +493,7 @@ class RemoteSequenceManager:
|
||||
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
|
||||
) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
|
||||
"""
|
||||
return a sequence of compression codecs for client-side compression (applied to tensors sent to remote server)
|
||||
:param peer_id: remote server's PeerID
|
||||
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
|
||||
:param args: request-specific input tensors
|
||||
|
@ -31,20 +31,27 @@ logger = get_logger(__name__)
|
||||
|
||||
async def run_rpc_forward(
|
||||
*flat_tensors: torch.Tensor,
|
||||
args_structure: Any,
|
||||
requested_backends: Sequence[TransformerBackend],
|
||||
active_adapter: str = "",
|
||||
prioritizer: TaskPrioritizerBase,
|
||||
points: int = 0,
|
||||
args_structure: Any,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
||||
|
||||
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
|
||||
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
|
||||
:param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
|
||||
:note: see pack_args_kwargs function for the definition of args_structure
|
||||
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
|
||||
:param active_adapter: the name of LoRA adapter to use; defaults to no adapter
|
||||
:param prioritizer: assigns priorities to each sub-request based on the number of points
|
||||
:param points: client-specified number of points, used to assign priorities
|
||||
:param args_structure:
|
||||
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
|
||||
"""
|
||||
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
|
||||
flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
|
||||
(hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
|
||||
dtype = requested_backends[0].dtype
|
||||
# check parse input tensors and cast dtypes
|
||||
@ -77,7 +84,7 @@ async def run_rpc_forward(
|
||||
hidden_states.ndim == 3
|
||||
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
||||
|
||||
return hidden_states
|
||||
return hidden_states.requires_grad_(requires_grad)
|
||||
|
||||
|
||||
async def run_rpc_backward(
|
||||
@ -88,19 +95,22 @@ async def run_rpc_backward(
|
||||
points: int = 0,
|
||||
args_structure: Any,
|
||||
) -> Tuple[Sequence[torch.Tensor], Any]:
|
||||
"""A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
|
||||
assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
|
||||
((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
|
||||
requested_backends, flat_tensors, args_structure
|
||||
)
|
||||
input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
|
||||
|
||||
# Cast inputs & grad outputs to backend dtype
|
||||
num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
|
||||
hidden_states = hidden_states.to(requested_backends[0].dtype)
|
||||
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
|
||||
hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
|
||||
grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)
|
||||
|
||||
if prompts is None or is_dummy(prompts):
|
||||
prompts = [DUMMY] * len(requested_backends)
|
||||
else:
|
||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
||||
prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]
|
||||
|
||||
# Run a forward chain to collect intermediate inputs
|
||||
# Note that we do not forward for the last module since we do not need its output
|
||||
@ -140,7 +150,7 @@ async def run_rpc_backward(
|
||||
active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
|
||||
)
|
||||
assert isinstance(grad_outputs, torch.Tensor)
|
||||
if not is_dummy(prompt):
|
||||
if not is_dummy(prompt) and prompts_requires_grad:
|
||||
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|
||||
grad_backend_kwargs_reversed.append(grad_kwargs)
|
||||
|
||||
|
@ -361,18 +361,19 @@ class TransformerConnectionHandler(ConnectionHandler):
|
||||
active_adapter = self._get_active_adapter(metadata)
|
||||
points = metadata.get("points", 0)
|
||||
args_structure = metadata.get("args_structure")
|
||||
|
||||
assert isinstance(
|
||||
points, (float, int)
|
||||
), f"rpc_forward should have number of points as number or None, got {points}"
|
||||
|
||||
hidden_states = await run_rpc_forward(
|
||||
*flat_inputs,
|
||||
requested_backends=requested_backends,
|
||||
prioritizer=self._prioritizer,
|
||||
active_adapter=active_adapter,
|
||||
points=points,
|
||||
args_structure=args_structure,
|
||||
requested_backends=requested_backends,
|
||||
active_adapter=active_adapter,
|
||||
prioritizer=self._prioritizer,
|
||||
points=points,
|
||||
)
|
||||
|
||||
return runtime_pb2.ExpertResponse(
|
||||
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
|
||||
)
|
||||
@ -396,11 +397,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
||||
|
||||
hidden_states = await run_rpc_forward(
|
||||
*flat_inputs,
|
||||
requested_backends=requested_backends,
|
||||
prioritizer=self._prioritizer,
|
||||
active_adapter=active_adapter,
|
||||
points=points,
|
||||
args_structure=args_structure,
|
||||
requested_backends=requested_backends,
|
||||
active_adapter=active_adapter,
|
||||
prioritizer=self._prioritizer,
|
||||
points=points,
|
||||
)
|
||||
|
||||
# Split the serialized_output for streaming and respond to client
|
||||
@ -450,8 +451,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
||||
flat_grads, grads_structure = await run_rpc_backward(
|
||||
*flat_tensors,
|
||||
requested_backends=requested_backends,
|
||||
prioritizer=self._prioritizer,
|
||||
active_adapter=active_adapter,
|
||||
prioritizer=self._prioritizer,
|
||||
points=points,
|
||||
args_structure=args_structure,
|
||||
)
|
||||
@ -479,8 +480,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
||||
flat_grads, grad_structure = await run_rpc_backward(
|
||||
*flat_tensors,
|
||||
requested_backends=requested_backends,
|
||||
prioritizer=self._prioritizer,
|
||||
active_adapter=active_adapter,
|
||||
prioritizer=self._prioritizer,
|
||||
points=points,
|
||||
args_structure=args_structure,
|
||||
)
|
||||
|
@ -73,8 +73,8 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
|
||||
rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs)
|
||||
return rpc_info
|
||||
|
||||
def get_request_metadata(self, protocol: str, *args, **kwargs):
|
||||
metadata = super().get_request_metadata(protocol, *args, **kwargs)
|
||||
def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs):
|
||||
metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs)
|
||||
if protocol == "rpc_forward":
|
||||
metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
|
||||
elif protocol == "rpc_backward":
|
||||
|
Loading…
Reference in New Issue
Block a user