wip (again)

partial_rollback
Your Name 10 months ago
parent 65e87395bc
commit 13c13d347a

@ -35,7 +35,7 @@ async def run_rpc_forward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
structure: Any,
args_structure: Any,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@ -45,7 +45,7 @@ async def run_rpc_forward(
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
(hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
(hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
dtype = requested_backends[0].dtype
# check parsed input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
@ -247,10 +247,10 @@ async def iterate_rpc_inference(
def _check_inputs(
requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any
requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
):
if structure is not None:
args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)
if args_structure is not None:
args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
else:
args, *backend_kwargs = flat_tensors, {} # backward compatibility

@ -368,7 +368,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
structure=args_structure,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -397,7 +397,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
structure=args_structure,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client

Loading…
Cancel
Save