make it work for fwd, bwd

This commit is contained in:
Your Name 2023-09-06 01:28:01 +03:00
parent 465fd93147
commit f2049658b6

View File

@ -135,7 +135,7 @@ async def run_rpc_backward(
priority = prioritizer.prioritize(
hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
(grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
)
assert isinstance(grad_outputs, torch.Tensor)
@ -145,7 +145,7 @@ async def run_rpc_backward(
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
return pack_args_kwargs((grad_args, reversed(grad_backend_kwargs_reversed)))
return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
async def iterate_rpc_inference(