standardize: s/backend_kwargs/block_kwargs/g everywhere

pull/467/head
Your Name 9 months ago
parent 68b8cea246
commit 8eb1722f1e

@@ -242,12 +242,12 @@ class _MergedInferenceStep:
         hypo_ids: torch.LongTensor,
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
-        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
+        block_kwargs: Sequence[Dict[str, torch.Tensor]],
     ) -> Tuple[torch.Tensor, ...]:
         assert (
-            len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
-        ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(backend_kwargs)} kwargs"
-        for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, backend_kwargs):
+            len(inference_infos) == len(optional_prompts) == len(block_kwargs)
+        ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(block_kwargs)} kwargs"
+        for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, block_kwargs):
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(
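
A minimal, self-contained sketch of the per-block contract that _MergedInferenceStep enforces above: one optional prompt and one kwargs dict per block, matched by position. This is plain PyTorch; run_blocks and blocks are hypothetical stand-ins, not petals APIs.

from typing import Dict, Optional, Sequence

import torch


def run_blocks(
    hidden_states: torch.Tensor,
    blocks: Sequence[torch.nn.Module],
    optional_prompts: Sequence[Optional[torch.Tensor]],
    block_kwargs: Sequence[Dict[str, torch.Tensor]],
) -> torch.Tensor:
    # One prompt and one kwargs dict per block, consumed in order.
    assert len(blocks) == len(optional_prompts) == len(block_kwargs)
    for block, prompt, kwargs in zip(blocks, optional_prompts, block_kwargs):
        if prompt is not None:
            hidden_states[:, : prompt.shape[1]] += prompt
        hidden_states = block(hidden_states, **kwargs)
    return hidden_states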

@@ -52,7 +52,7 @@ async def run_rpc_forward(
     """
     requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
     flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
-    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -64,7 +64,7 @@ async def run_rpc_forward(
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
-    for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
@@ -97,7 +97,7 @@ async def run_rpc_backward(
 ) -> Tuple[Sequence[torch.Tensor], Any]:
     """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
     assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
-    ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
+    ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
     input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
@@ -115,7 +115,7 @@ async def run_rpc_backward(
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs):
         assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
@@ -135,11 +135,11 @@ async def run_rpc_backward(
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
-    grad_backend_kwargs_reversed = []
+    grad_block_kwargs_reversed = []
     # Run a chain of requested backends
     for hidden_states, prompt, backend, kwargs in reversed(
-        list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))
+        list(zip(inter_inputs, prompts, requested_backends, block_kwargs))
     ):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         hidden_states = hidden_states.detach().requires_grad_(True)
@@ -152,11 +152,11 @@ async def run_rpc_backward(
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-        grad_backend_kwargs_reversed.append(grad_kwargs)
+        grad_block_kwargs_reversed.append(grad_kwargs)
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
-    return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
+    return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed))))
 
 
 async def iterate_rpc_inference(
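
The backward hunks above re-run the forward chain to recover each block's input, then walk the blocks in reverse, backpropagating the incoming gradient and collecting one grad entry per block. A self-contained sketch of that pattern in plain PyTorch, using toy Linear blocks with no prompts and no per-block kwargs:

import torch

blocks = [torch.nn.Linear(4, 4) for _ in range(3)]
hidden_states = torch.randn(1, 2, 4)

# Forward chain: record the input of every block (the last block's output is not needed).
inter_inputs = []
h = hidden_states
for block in blocks[:-1]:
    inter_inputs.append(h)
    h = block(h)
inter_inputs.append(h)  # input of the last block
assert len(inter_inputs) == len(blocks)

# Backward chain: visit the blocks in reverse, propagating grad_outputs toward the first block.
grad_outputs = torch.ones(1, 2, 4)
grad_block_kwargs_reversed = []
for h, block in reversed(list(zip(inter_inputs, blocks))):
    h = h.detach().requires_grad_(True)
    block(h).backward(grad_outputs)
    grad_outputs = h.grad
    grad_block_kwargs_reversed.append({})  # toy blocks take no per-block kwargs

grad_block_kwargs = list(reversed(grad_block_kwargs_reversed))
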
@@ -179,7 +179,7 @@ async def iterate_rpc_inference(
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+        (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs(
             requested_backends, flat_tensors, args_structure
         )
         batch_size, length_increment, _ = hidden_states.shape
@@ -230,13 +230,13 @@ async def iterate_rpc_inference(
                     hypo_ids,
                     inference_infos,
                     *prompts,
-                    backend_kwargs=backend_kwargs,
+                    block_kwargs=block_kwargs,
                     priority=priority,
                     size=num_tokens,
                 )
             else:
                 for backend, uid, handles, prompt, kwargs in zip(
-                    requested_backends, requested_uids, cache_handles, prompts, backend_kwargs
+                    requested_backends, requested_uids, cache_handles, prompts, block_kwargs
                 ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
@@ -244,7 +244,7 @@ async def iterate_rpc_inference(
                         hypo_ids,
                         inference_infos,
                         prompt,
-                        backend_kwargs=(kwargs,),
+                        block_kwargs=(kwargs,),
                         priority=priority,
                         size=num_tokens,
                     )
@@ -269,19 +269,19 @@
         hidden_states, grad_outputs, prompts = flat_tensors
         flat_tensors = grad_outputs, hidden_states, prompts
 
     if args_structure is not None:
-        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+        args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
-        args, *backend_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
-    if len(backend_kwargs) not in (1, len(requested_backends)):
+        args, *block_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
+    if len(block_kwargs) not in (1, len(requested_backends)):
         raise RuntimeError(
             f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
-            f"(one for each block). Found {len(backend_kwargs)} instead."
+            f"(one for each block). Found {len(block_kwargs)} instead."
         )
-    if len(backend_kwargs) == 1:
-        backend_kwargs = backend_kwargs * len(requested_backends)
-    assert len(backend_kwargs) == len(requested_backends)
-    for i, kwargs in enumerate(backend_kwargs):
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * len(requested_backends)
+    assert len(block_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(block_kwargs):
         if not isinstance(kwargs, dict):
             raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
-    return args, backend_kwargs
+    return args, block_kwargs
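
The _check_inputs hunk above accepts either a single kwargs dict for the whole chain or one dict per requested block, replicating a single dict so that every block gets its own. A minimal sketch of that broadcast rule as a standalone helper (broadcast_block_kwargs is a hypothetical name):

from typing import Any, Dict, List, Sequence


def broadcast_block_kwargs(block_kwargs: Sequence[Dict[str, Any]], num_blocks: int) -> List[Dict[str, Any]]:
    if len(block_kwargs) not in (1, num_blocks):
        raise RuntimeError(
            f"Expected either one dict of keyword arguments or {num_blocks} dicts "
            f"(one for each block), found {len(block_kwargs)} instead"
        )
    if len(block_kwargs) == 1:
        return list(block_kwargs) * num_blocks  # the same dict object is shared by every block
    return list(block_kwargs)


assert len(broadcast_block_kwargs([{}], num_blocks=3)) == 3          # one dict, broadcast to all blocks
assert len(broadcast_block_kwargs([{}, {}, {}], num_blocks=3)) == 3  # per-block dicts, kept as-is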
