@@ -31,20 +31,27 @@ logger = get_logger(__name__)

 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
-    args_structure: Any,
     requested_backends: Sequence[TransformerBackend],
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
     :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
     :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
+    :note: see pack_args_kwargs function for the definition of args_structure
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :param active_adapter: the name of LoRA adapter to use; defaults to no adapter
+    :param prioritizer: assigns priorities to each sub-request based on the number of points
+    :param points: client-specified number of points, used to assign priorities
-    :param args_structure:
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
+    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
+    flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
     (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
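The docstring additions describe the packing scheme used on the wire: the client flattens all positional and keyword tensor arguments into the single flat_tensors list, and args_structure is the schema that lets the server put them back. The real format is defined by the pack_args_kwargs helper mentioned in the :note:; the toy round trip below only illustrates the idea and shares no code with it.

# Toy flatten/unflatten round trip; illustrative only, not petals' actual schema.
from typing import Any, Dict, List, Sequence, Tuple

import torch


def pack(*args: torch.Tensor, **kwargs: torch.Tensor) -> Tuple[List[torch.Tensor], Any]:
    # Flatten positional and keyword tensors into one list, plus a schema
    # recording how many positionals there were and the sorted kwarg names.
    structure = {"nargs": len(args), "keys": sorted(kwargs)}
    flat = list(args) + [kwargs[k] for k in structure["keys"]]
    return flat, structure


def unpack(flat: Sequence[torch.Tensor], structure: Any) -> Tuple[Tuple, Dict]:
    # Inverse of pack: split the flat list back into args and kwargs.
    args = tuple(flat[: structure["nargs"]])
    kwargs = dict(zip(structure["keys"], flat[structure["nargs"]:]))
    return args, kwargs


flat, schema = pack(torch.zeros(2, 3), attention_mask=torch.ones(2, 3))
args, kwargs = unpack(flat, schema)
assert args[0].shape == (2, 3) and "attention_mask" in kwargs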
@@ -77,7 +84,7 @@ async def run_rpc_forward
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

-    return hidden_states
+    return hidden_states.requires_grad_(requires_grad)


 async def run_rpc_backward(
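Together with the two new lines at the top of run_rpc_forward, this return statement completes a detach-then-restore pattern: every incoming tensor is detached so the server never extends the client's autograd graph, while the one bit of information that matters, whether any input wanted gradients, is re-attached to the output before serialization. A minimal sketch of the same pattern, using a stand-in layer (the real backend pools already run without autograd):

import torch


def forward_detached(layer: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    # Record whether any caller tensor expects gradients...
    requires_grad = any(t.requires_grad for t in inputs)
    # ...then cut all inputs from the caller's autograd graph.
    inputs = tuple(t.detach() for t in inputs)
    with torch.no_grad():  # stand-in for the server's no-autograd execution
        output = layer(*inputs)
    # Restore the flag so the serialized output still signals "grads wanted".
    return output.requires_grad_(requires_grad)


hidden = torch.randn(2, 4, 8, requires_grad=True)
out = forward_detached(torch.nn.Linear(8, 8), hidden)
assert out.requires_grad and out.grad_fn is None  # flag kept, graph dropped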
@@ -88,19 +95,22 @@ async def run_rpc_backward
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
     """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
+    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
     ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
+    input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
+
     # Cast inputs & grad outputs to backend dtype
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
-    hidden_states = hidden_states.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)

     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+        prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
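Note the ordering in this hunk: input_requires_grad and prompts_requires_grad are captured right after _check_inputs returns, before any of the new detach() calls run. This matters because detach() always returns a tensor with requires_grad=False, so flags read afterwards would be False unconditionally:

import torch

prompts = torch.randn(3, 1, 4, 8, requires_grad=True)
prompts_requires_grad = prompts.requires_grad  # capture first: True
prompts = prompts.detach()                     # detached copy: requires_grad=False
assert prompts_requires_grad and not prompts.requires_grad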
@@ -140,7 +150,7 @@ async def run_rpc_backward
             active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
         assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
+        if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
         grad_backend_kwargs_reversed.append(grad_kwargs)
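The gate added in this last hunk means prompt gradients are collected only when the client's prompts actually required them; previously the server accumulated and returned grad_prompts even for frozen prompts. For reference, is_dummy / DUMMY implement the placeholder convention for missing prompts; the snippet below mirrors my reading of petals.utils.misc and should be treated as an assumption rather than the authoritative definition:

import torch

DUMMY = torch.empty(0)  # assumed placeholder convention from petals.utils.misc


def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0


prompts = [DUMMY, torch.randn(1, 4, 8)]
for prompt in prompts:
    print("no prompt" if is_dummy(prompt) else f"prompt {tuple(prompt.shape)}")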