@@ -135,7 +135,7 @@ async def run_rpc_backward(
         priority = prioritizer.prioritize(
             hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
+        (grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
             active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
         assert isinstance(grad_outputs, torch.Tensor)
@@ -145,7 +145,7 @@ async def run_rpc_backward(
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
-    return pack_args_kwargs((grad_args, reversed(grad_backend_kwargs_reversed)))
+    return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
 
 
 async def iterate_rpc_inference(