|
|
|
@@ -35,7 +35,7 @@ async def run_rpc_forward(
|
|
|
|
|
active_adapter: str = "",
|
|
|
|
|
prioritizer: TaskPrioritizerBase,
|
|
|
|
|
points: int = 0,
|
|
|
|
|
-    structure: Any,
|
|
|
|
|
+    args_structure: Any,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
|
|
@@ -45,7 +45,7 @@ async def run_rpc_forward(
|
|
|
|
|
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
|
|
|
|
|
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
|
|
|
|
|
"""
|
|
|
|
|
-    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
|
|
|
|
|
+    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
|
|
|
|
|
dtype = requested_backends[0].dtype
|
|
|
|
|
# check parse input tensors and cast dtypes
|
|
|
|
|
hidden_states = hidden_states.to(dtype)
|
|
|
|
@@ -247,10 +247,10 @@ async def iterate_rpc_inference(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_inputs(
|
|
|
|
|
-    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any
|
|
|
|
|
+    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
|
|
|
|
|
):
|
|
|
|
|
-    if structure is not None:
|
|
|
|
|
-        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)
|
|
|
|
|
+    if args_structure is not None:
|
|
|
|
|
+        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
|
|
|
|
|
else:
|
|
|
|
|
args, *backend_kwargs = flat_tensors, {} # backward compatibility
|
|
|
|
|
|
|
|
|
|