|
|
|
@ -128,8 +128,9 @@ async def run_rpc_backward(
|
|
|
|
|
grad_backend_kwargs_reversed = []
|
|
|
|
|
|
|
|
|
|
# Run a chain of requested backends
|
|
|
|
|
for hidden_states, prompt, backend, kwargs in reversed(list(zip(
|
|
|
|
|
inter_inputs, prompts, requested_backends, backend_kwargs))):
|
|
|
|
|
for hidden_states, prompt, backend, kwargs in reversed(
|
|
|
|
|
list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))
|
|
|
|
|
):
|
|
|
|
|
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
|
|
|
hidden_states = hidden_states.detach().requires_grad_(True)
|
|
|
|
|
priority = prioritizer.prioritize(
|
|
|
|
|