black, isort

partial_rollback
Your Name 10 months ago
parent 13c13d347a
commit 4529471f3f

@@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
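Note: the isort change above only reorders the imported names alphabetically; behavior is unchanged. For context, a minimal round-trip sketch of what these two helpers do, assuming the usual flatten/unflatten contract: pack_args_kwargs(*args, **kwargs) returns (flat_tensors, structure), and unpack_args_kwargs(flat_tensors, structure) returns (args, kwargs). Treat the exact signatures as an assumption rather than a reading of petals' source:

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

# Assumed contract: every tensor found in args/kwargs is flattened into a
# list, plus an opaque structure that lets the receiver rebuild the call
x = torch.ones(2, 3)
flat_tensors, structure = pack_args_kwargs(x, scale=torch.tensor(0.5))

# Rebuild (args, kwargs) from the flat tensors and the recorded structure
args, kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(args[0], x) and torch.equal(kwargs["scale"], torch.tensor(0.5))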
@@ -88,7 +88,9 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(
+        requested_backends, flat_tensors, args_structure
+    )
     # Cast inputs & grad outputs to backend dtype
     assert hidden_states.ndim == 3
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
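Note: the hunk above is a pure black reformat. With black's default 88-character line length, a call whose one-line form is too long has its arguments moved onto an indented continuation line, with the closing parenthesis dedented. A generic illustration (the function and arguments are made up so the snippet runs on its own):

def combine(first_argument, second_argument, third_rather_long_argument):
    # Hypothetical helper, defined only so both spellings below execute
    return first_argument + second_argument + third_rather_long_argument

# Before black (in real code this line would exceed the 88-character limit):
total = combine(first_argument=1000, second_argument=2000, third_rather_long_argument=3000)

# After black: arguments on their own indented line, closing paren dedented
total = combine(
    first_argument=1000, second_argument=2000, third_rather_long_argument=3000
)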
@@ -166,7 +168,9 @@ async def iterate_rpc_inference(
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+            requested_backends, flat_tensors, args_structure
+        )
         batch_size, length_increment, _ = hidden_states.shape
         num_tokens = batch_size * length_increment
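Note: both reformatted call sites feed the same token accounting. hidden_states uses a [batch, sequence, hidden_dim] layout, so the number of tokens handled in a step is the product of the first two dimensions. A self-contained sketch (the tensor is a stand-in):

import torch

hidden_states = torch.randn(2, 5, 4096)  # hypothetical: batch of 2, 5 new tokens each
assert hidden_states.ndim == 3
batch_size, length_increment, _ = hidden_states.shape
num_tokens = batch_size * length_increment  # 10 tokens processed in this step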

@@ -502,15 +502,19 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> Sequence[runtime_pb2.Tensor]:
         """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
         inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad)
-        assert len(flat_grads) == len(inputs_with_grad), f"user provides {len(inputs_with_grad)} inputs with grad, " \
-            f"but backward produced {len(flat_grads)} gradients"
+        assert len(flat_grads) == len(inputs_with_grad), (
+            f"user provides {len(inputs_with_grad)} inputs with grad, "
+            f"but backward produced {len(flat_grads)} gradients"
+        )
         # Modify grad_inputs_schema to support grad_prompts
         if input_metadata.get("output_compression") is not None:
             output_compression = input_metadata["output_compression"]
             assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
             assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
-            assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \
-                f"elements, one for every tensor thar requires grad"
+            assert len(output_compression) == len(flat_grads), (
+                f"output_compression should have {len(flat_grads)} "
+                f"elements, one for every tensor thar requires grad"
+            )
         else:
             output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
         output_compression = tuple(output_compression)
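Note: the branch being reformatted lets a client request a per-tensor compression scheme for the returned gradients, defaulting to no compression. A self-contained restatement with stand-in values (the metadata dict and the two-element gradient list are hypothetical; runtime_pb2.NONE mirrors the handler's own fallback):

from hivemind.proto import runtime_pb2

flat_grads = ["grad_a", "grad_b"]  # stand-ins for two gradient tensors
input_metadata = {"output_compression": [runtime_pb2.NONE, runtime_pb2.NONE]}

output_compression = input_metadata.get("output_compression")
if output_compression is not None:
    # Client supplied one integer compression code per gradient tensor
    assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
    assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
    assert len(output_compression) == len(flat_grads), "one compression code per gradient"
else:
    # No preference given: send every gradient uncompressed
    output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
output_compression = tuple(output_compression)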
