@@ -52,7 +52,7 @@ async def run_rpc_forward(
     """
     requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
     flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
-    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -64,7 +64,7 @@ async def run_rpc_forward(
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
-    for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
 
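For context, the forward chain above adds each non-dummy deep prompt onto the first prompt.shape[1] positions of the hidden states before running the corresponding block. Below is a minimal standalone sketch of that convention (not part of the patch); the toy shapes and the local DUMMY/is_dummy helpers only mimic the ones imported by this module.

import torch

DUMMY = torch.empty(0)  # placeholder meaning "no prompt for this block"

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

hidden_states = torch.zeros(2, 16, 8)   # [batch, seq_len, hidden]
prompts = [torch.ones(2, 4, 8), DUMMY]  # one deep prompt, one block without a prompt

for prompt in prompts:
    if not is_dummy(prompt):
        hidden_states[:, : prompt.shape[1]] += prompt  # only the prompted prefix changes

assert hidden_states[:, 4:].abs().sum() == 0  # positions past the prompt stay untouched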
@@ -97,7 +97,7 @@ async def run_rpc_backward(
 ) -> Tuple[Sequence[torch.Tensor], Any]:
     """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
     assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
-    ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
+    ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
     input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
@@ -115,7 +115,7 @@ async def run_rpc_backward(
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs):
         assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
@@ -135,11 +135,11 @@ async def run_rpc_backward(
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
-    grad_backend_kwargs_reversed = []
+    grad_block_kwargs_reversed = []
 
     # Run a chain of requested backends
     for hidden_states, prompt, backend, kwargs in reversed(
-        list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))
+        list(zip(inter_inputs, prompts, requested_backends, block_kwargs))
     ):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         hidden_states = hidden_states.detach().requires_grad_(True)
@@ -152,11 +152,11 @@ async def run_rpc_backward(
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-        grad_backend_kwargs_reversed.append(grad_kwargs)
+        grad_block_kwargs_reversed.append(grad_kwargs)
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
-    return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
+    return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed))))
 
 
 async def iterate_rpc_inference(
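The backward chain above visits blocks last-to-first, so per-block results are appended in reverse and flipped back into forward order before packing. A minimal standalone sketch of that bookkeeping (illustrative values only, not part of the patch):

import torch

grad_prompts_reversed = []
grad_block_kwargs_reversed = []
for block_index in reversed(range(3)):  # walk blocks last-to-first
    grad_prompts_reversed.append(torch.full((1, 2, 4, 8), float(block_index)))  # mimics unsqueeze(0)
    grad_block_kwargs_reversed.append({"block": block_index})

grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0)   # restore forward order, one slice per block
grad_block_kwargs = list(reversed(grad_block_kwargs_reversed))

assert grad_prompts.shape == (3, 2, 4, 8)
assert [kw["block"] for kw in grad_block_kwargs] == [0, 1, 2]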
@@ -179,7 +179,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+        (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs(
             requested_backends, flat_tensors, args_structure
         )
         batch_size, length_increment, _ = hidden_states.shape
@@ -230,13 +230,13 @@ async def iterate_rpc_inference(
                 hypo_ids,
                 inference_infos,
                 *prompts,
-                backend_kwargs=backend_kwargs,
+                block_kwargs=block_kwargs,
                 priority=priority,
                 size=num_tokens,
             )
         else:
             for backend, uid, handles, prompt, kwargs in zip(
-                requested_backends, requested_uids, cache_handles, prompts, backend_kwargs
+                requested_backends, requested_uids, cache_handles, prompts, block_kwargs
             ):
                 inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                 (hidden_states,) = await backend.inference_pool.submit_task(
@@ -244,7 +244,7 @@ async def iterate_rpc_inference(
                     hypo_ids,
                     inference_infos,
                     prompt,
-                    backend_kwargs=(kwargs,),
+                    block_kwargs=(kwargs,),
                     priority=priority,
                     size=num_tokens,
                 )
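Note the two shapes in which kwargs reach submit_task above: the merged-pool path forwards the whole per-block tuple in one call, while the per-block fallback wraps each dict in a one-element tuple. A tiny standalone illustration (hypothetical values, not part of the patch):

block_kwargs = ({"adapter": "a"}, {"adapter": "b"})  # one dict per requested block

merged_call = block_kwargs                                # single submit_task for the whole chain
per_block_calls = [(kwargs,) for kwargs in block_kwargs]  # one submit_task per block

assert per_block_calls == [({"adapter": "a"},), ({"adapter": "b"},)]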
@@ -269,19 +269,19 @@ def _check_inputs(
         hidden_states, grad_outputs, prompts = flat_tensors
         flat_tensors = grad_outputs, hidden_states, prompts
     if args_structure is not None:
-        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+        args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
-        args, *backend_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
+        args, *block_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
 
-    if len(backend_kwargs) not in (1, len(requested_backends)):
+    if len(block_kwargs) not in (1, len(requested_backends)):
         raise RuntimeError(
             f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
-            f"(one for each block). Found {len(backend_kwargs)} instead."
+            f"(one for each block). Found {len(block_kwargs)} instead."
         )
-    if len(backend_kwargs) == 1:
-        backend_kwargs = backend_kwargs * len(requested_backends)
-    assert len(backend_kwargs) == len(requested_backends)
-    for i, kwargs in enumerate(backend_kwargs):
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * len(requested_backends)
+    assert len(block_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(block_kwargs):
         if not isinstance(kwargs, dict):
             raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
-    return args, backend_kwargs
+    return args, block_kwargs
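_check_inputs enforces a simple broadcast rule for the renamed block_kwargs: a single dict applies to every requested block, otherwise the client must send exactly one dict per block. A standalone sketch of that rule (toy helper and values, not part of the patch):

def normalize_block_kwargs(block_kwargs, num_blocks):
    if len(block_kwargs) not in (1, num_blocks):
        raise RuntimeError(f"Expected 1 or {num_blocks} kwargs dicts, found {len(block_kwargs)}")
    if len(block_kwargs) == 1:
        block_kwargs = block_kwargs * num_blocks  # broadcast the single dict to all blocks
    return block_kwargs

assert normalize_block_kwargs([{"alpha": 1}], 3) == [{"alpha": 1}] * 3
assert len(normalize_block_kwargs([{}, {}, {}], 3)) == 3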