@@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -88,7 +88,9 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(
+        requested_backends, flat_tensors, args_structure
+    )
     # Cast inputs & grad outputs to backend dtype
     assert hidden_states.ndim == 3
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
@@ -166,7 +168,9 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+            requested_backends, flat_tensors, args_structure
+        )
         batch_size, length_increment, _ = hidden_states.shape
         num_tokens = batch_size * length_increment
 
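
For context, both hunks above derive num_tokens from the shape of hidden_states before any work is scheduled. The snippet below is a minimal sketch of that bookkeeping, assuming the usual (batch_size, sequence_length, hidden_size) layout; the concrete sizes are made up for illustration and are not taken from this change.

import torch

# Hypothetical shapes for illustration only: 2 sequences, 5 new tokens each, hidden size 16.
hidden_states = torch.zeros(2, 5, 16)

# run_rpc_backward multiplies the first two dimensions directly.
assert hidden_states.ndim == 3
num_tokens_backward = hidden_states.shape[0] * hidden_states.shape[1]

# iterate_rpc_inference unpacks the same dimensions by name before multiplying.
batch_size, length_increment, _ = hidden_states.shape
num_tokens_inference = batch_size * length_increment

assert num_tokens_backward == num_tokens_inference == 10

Either way, the count is the number of batch entries times the number of newly processed positions, so the backward and inference paths account for work in the same unit.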