@@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import unpack_args_kwargs
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
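Review note (not part of the diff): `args_structure` is the schema produced by `pack_args_kwargs`, which flattens a nest of args/kwargs into a list of tensors, and `unpack_args_kwargs` inverts it. A minimal sketch of that flatten/unflatten contract, deliberately simplified to one level of nesting — the real functions in `petals.utils.packaging` handle arbitrary nesting and their schema format may differ:

```python
import torch

def pack_demo(*args, **kwargs):
    """Toy flatten: pull tensors out of (args, kwargs) and record a schema
    describing where each one came from. One flat level only."""
    flat_tensors = [*args, *kwargs.values()]
    schema = (len(args), tuple(kwargs.keys()))
    return flat_tensors, schema

def unpack_demo(flat_tensors, schema):
    """Inverse of pack_demo: rebuild (args, kwargs) from the flat tensor list."""
    n_args, kw_keys = schema
    return tuple(flat_tensors[:n_args]), dict(zip(kw_keys, flat_tensors[n_args:]))

flat, schema = pack_demo(torch.ones(2, 3), attention_mask=torch.zeros(2, 2))
args, kwargs = unpack_demo(flat, schema)
assert torch.equal(args[0], torch.ones(2, 3)) and "attention_mask" in kwargs
```

This is why the client can ship per-block keyword arguments over the wire as a flat tensor list plus a schema, and why the import above now also needs `pack_args_kwargs` on the server side (for packing gradients on the return path).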
@@ -31,36 +31,40 @@ logger = get_logger(__name__)

 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
+    args_structure: Any,
     requested_backends: Sequence[TransformerBackend],
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

     :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
+    :note: see pack_args_kwargs function for the definition of args_structure
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :param active_adapter: the name of the LoRA adapter to use; defaults to no adapter
     :param prioritizer: assigns priorities to each sub-request based on the number of points
     :param points: client-specified number of points, used to assign priorities
-    :param args_structure:
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    hidden_states, prompts, *_ = flat_tensors
+    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
+    flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
+    (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
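Review note: the chain above injects each block's optional deep prompt in place before submitting that block's forward task; only the first `prompt.shape[1]` positions of the sequence are touched. A minimal shape-level illustration (all dimensions below are assumed for the example):

```python
import torch

batch_size, seq_len, hid_size, prompt_len = 2, 8, 16, 3
hidden_states = torch.randn(batch_size, seq_len, hid_size)
prompt = torch.randn(batch_size, prompt_len, hid_size)

# Same in-place update as the chain above: only the first prompt_len
# positions of the sequence receive the trained prompt embeddings.
hidden_states[:, : prompt.shape[1]] += prompt
```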
@@ -69,16 +73,18 @@ async def run_rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
             active_adapter,
+            hidden_states,
+            **kwargs,
             priority=priority,
+            size=num_tokens,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

-    return hidden_states
+    return hidden_states.requires_grad_(requires_grad)


 async def run_rpc_backward(
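Review note: the point of capturing `requires_grad` up front and re-applying it on the way out is that the server detaches every input (so no autograd graph is built across the pooled forward chain) while the client still needs to know whether gradients may flow through the returned hidden states. `Tensor.requires_grad_` restores that single flag without recording any history — a small sketch of the semantics, outside the server code:

```python
import torch

x = torch.randn(1, 4, 8, requires_grad=True)

requires_grad = x.requires_grad          # remember the client's intent
h = x.detach()                           # no graph is built server-side
h = h * 2.0                              # stand-in for the pooled forward chain
h = h.requires_grad_(requires_grad)      # restore the flag for serialization

assert h.requires_grad and h.grad_fn is None  # flag is back, history is not
```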
@@ -87,58 +93,70 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    inputs, grad_outputs, prompts, *_ = flat_tensors
+    args_structure: Any,
+) -> Tuple[Sequence[torch.Tensor], Any]:
+    """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
+    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
+    ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs(
+        requested_backends, flat_tensors, args_structure
+    )
+    input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad

     # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
+    hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)

     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+        prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs):
+        assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
+            hidden_states[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(hidden_states)
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
         priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
-
-        assert isinstance(inputs, torch.Tensor)
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
+        )
+        assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor"

     if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
+        hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(hidden_states)

     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
+    grad_block_kwargs_reversed = []

     # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+    for hidden_states, prompt, backend, kwargs in reversed(
+        list(zip(inter_inputs, prompts, requested_backends, block_kwargs))
+    ):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        hidden_states = hidden_states.detach().requires_grad_(True)
         priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+            hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+        (grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
+            active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
+        )

         assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
+        if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+        grad_block_kwargs_reversed.append(grad_kwargs)

     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+    grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
+    return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed))))


 async def iterate_rpc_inference(
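Review note: the rewritten backward is essentially gradient checkpointing across blocks — the forward chain stores each block's detached input in `inter_inputs`, then the reversed loop re-enables grad on one input at a time and pulls the gradient back through a single block. A self-contained toy sketch of that pattern with `nn.Linear` blocks standing in for the pooled backends (hypothetical, not the server code; parameter gradients are ignored here):

```python
import torch
from torch import nn

blocks = [nn.Linear(8, 8) for _ in range(3)]
x = torch.randn(2, 8)

# Forward chain: remember each block's *input*, detached (like inter_inputs).
inter_inputs = []
h = x.detach()
for block in blocks:
    inter_inputs.append(h)
    h = block(h).detach()  # pooled backends return detached activations

# Backward chain: walk blocks in reverse, redoing one block's forward with grad.
grad_outputs = torch.ones(2, 8)  # gradient w.r.t. the last block's output
for block, inp in reversed(list(zip(blocks, inter_inputs))):
    inp = inp.detach().requires_grad_(True)
    out = block(inp)
    (grad_outputs,) = torch.autograd.grad(out, inp, grad_outputs)
# grad_outputs now holds d(loss)/d(x), produced one block at a time.
```

Only one block's activations are live with autograd history at any moment, which is what keeps the server's memory footprint flat for long chains.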
@@ -161,12 +179,11 @@ async def iterate_rpc_inference(

     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        if args_structure is not None:
-            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-
-        hidden_states, prompts, hypo_ids, *_ = flat_tensors
+        (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs(
+            requested_backends, flat_tensors, args_structure
+        )
         batch_size, length_increment, _ = hidden_states.shape
+        num_tokens = batch_size * length_increment

         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -209,13 +226,27 @@ async def iterate_rpc_inference(
                 for uid, handles in zip(requested_uids, cache_handles)
             )
             (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                hidden_states,
+                hypo_ids,
+                inference_infos,
+                *prompts,
+                block_kwargs=block_kwargs,
+                priority=priority,
+                size=num_tokens,
             )
         else:
-            for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+            for backend, uid, handles, prompt, kwargs in zip(
+                requested_backends, requested_uids, cache_handles, prompts, block_kwargs
+            ):
                 inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                 (hidden_states,) = await backend.inference_pool.submit_task(
-                    hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                    hidden_states,
+                    hypo_ids,
+                    inference_infos,
+                    prompt,
+                    block_kwargs=(kwargs,),
+                    priority=priority,
+                    size=num_tokens,
                 )

         # serialize and send last layer outputs
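Review note: the `if/else` above is the merged-pool optimization from the comment at the top of the file — a short inference step is submitted once to `requested_backends[0].inference_pool` carrying the whole `block_kwargs` list, while longer steps go block by block, each with a one-element `(kwargs,)` tuple. A stub-pool sketch of the two fan-out shapes (`DemoPool` is hypothetical; only the argument shapes mirror the diff):

```python
import asyncio

class DemoPool:
    """Hypothetical stand-in for a PrioritizedTaskPool: just reports its inputs."""
    async def submit_task(self, *args, block_kwargs, priority=0.0, size=0):
        return (f"{len(block_kwargs)} kwargs dict(s), size={size}",)

async def main():
    pools = [DemoPool() for _ in range(3)]
    block_kwargs = [{"attention_mask": None}] * 3

    # Merged path: one task carries the kwargs of every requested block.
    (out,) = await pools[0].submit_task("hidden_states", block_kwargs=block_kwargs, size=16)
    print("merged:", out)      # -> 3 kwargs dict(s), size=16

    # Per-block path: each pool receives a one-element kwargs tuple.
    for pool, kwargs in zip(pools, block_kwargs):
        (out,) = await pool.submit_task("hidden_states", block_kwargs=(kwargs,), size=16)
        print("per-block:", out)  # -> 1 kwargs dict(s), size=16

asyncio.run(main())
```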
@@ -228,3 +259,29 @@ async def iterate_rpc_inference(

         # prepare for next step
         prefix_length += length_increment
+
+
+def _check_inputs(
+    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
+):
+    if len(flat_tensors) == 3:  # backward compatibility for rpc_backward, remove after 2.3
+        if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
+            hidden_states, grad_outputs, prompts = flat_tensors
+            flat_tensors = grad_outputs, hidden_states, prompts
+    if args_structure is not None:
+        args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    else:
+        args, *block_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
+
+    if len(block_kwargs) not in (1, len(requested_backends)):
+        raise RuntimeError(
+            f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
+            f"(one for each block). Found {len(block_kwargs)} instead."
+        )
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * len(requested_backends)
+    assert len(block_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(block_kwargs):
+        if not isinstance(kwargs, dict):
+            raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
+    return args, block_kwargs
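Review note: a usage sketch for the new helper, assuming `_check_inputs` is importable from the modified module. With `args_structure=None`, the legacy path passes the tensors through unchanged and broadcasts a single empty kwargs dict to every block, so old clients keep working:

```python
import torch
# Assumption: the helper lives in the module modified by this diff, e.g.
# from petals.server.block_functions import _check_inputs

requested_backends = [object()] * 4  # _check_inputs only takes len() of this
flat_tensors = (torch.zeros(1, 5, 8), torch.zeros(1, 1, 8))  # hidden_states, prompts

args, block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure=None)
assert args is flat_tensors          # legacy path: tensors pass through unchanged
assert block_kwargs == [{}] * 4      # one (empty) kwargs dict broadcast per block
```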