@@ -76,22 +76,22 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Run a chain of requested backends
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)

         # Serialize the overall output and respond
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+        )

     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
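Context for reviewers: both sides of this hunk format the same control flow, a sequential chain where each backend's forward pool consumes the hidden states produced by the previous backend. A minimal sketch of that pattern, with a hypothetical `ToyBackend` standing in for the real module backends and their pools:

```python
import asyncio

class ToyBackend:
    """Hypothetical stand-in for a module backend whose forward pool runs the real computation."""

    def __init__(self, scale: float):
        self.scale = scale

    async def submit_task(self, hidden_states):
        await asyncio.sleep(0)  # pretend the task was dispatched to a worker pool
        return [hidden_states * self.scale]  # pools return a list/tuple of outputs

async def run_chain(backends, hidden_states):
    # Feed each backend's output into the next one, as rpc_forward does above
    outputs = [hidden_states]
    for backend in backends:
        assert isinstance(outputs, (list, tuple)) and len(outputs) == 1
        outputs = await backend.submit_task(*outputs)
    return outputs[0]

print(asyncio.run(run_chain([ToyBackend(2.0), ToyBackend(3.0)], 1.0)))  # 6.0
```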
@@ -101,48 +101,41 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Run a chain of requested backends
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)

         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
+            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]

         # Split the serialized_output for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_output
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]

         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])

-    async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
         inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
+        # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-            ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)
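The rpc_forward_stream changes above preserve the streaming contract: serialized tensors are cut into parts no larger than DEFAULT_MAX_MSG_SIZE so each response message stays under the transport's size limit. A minimal sketch of that split/reassemble round trip, using illustrative `chunk`/`reassemble` helpers rather than the real `split_for_streaming`:

```python
MAX_MSG_SIZE = 4  # tiny limit for demonstration; the real DEFAULT_MAX_MSG_SIZE is far larger

def chunk(payload: bytes, max_size: int):
    # Yield consecutive slices no longer than max_size, preserving order
    for start in range(0, len(payload), max_size):
        yield payload[start : start + max_size]

def reassemble(parts) -> bytes:
    # The receiver concatenates the parts in arrival order to recover the payload
    return b"".join(parts)

payload = b"serialized-tensor-bytes"
parts = list(chunk(payload, MAX_MSG_SIZE))
assert all(len(part) <= MAX_MSG_SIZE for part in parts)
assert reassemble(parts) == payload
```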
@@ -150,16 +143,16 @@ class TransformerConnectionHandler(ConnectionHandler):
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]

         # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            ]
+        )

     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
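The rpc_backward flow above recomputes activations instead of storing them: it forwards through all but the last backend to collect each block's input, then walks the chain in reverse, pairing each cached input with the incoming gradient. A minimal sketch of that pattern with plain floats in place of 3d tensors (the `f`/`df` pairs are hypothetical):

```python
def backward_chain(blocks, inputs, grads):
    # blocks: list of (f, df) pairs, where df(x, g) maps an upstream gradient g
    # at f(x) to the gradient w.r.t. x (the chain rule, one block at a time)
    inter_inputs = [inputs]
    for f, _ in blocks[:-1]:  # the last block's output is never needed
        inputs = f(inputs)
        inter_inputs.append(inputs)
    for (_, df), inp in zip(blocks[::-1], inter_inputs[::-1]):
        grads = df(inp, grads)
    return grads

square = (lambda x: x * x, lambda x, g: 2 * x * g)
double = (lambda x: 2 * x, lambda x, g: 2 * g)
# d/dx of 2 * x**2 at x = 3 is 4 * x = 12
assert backward_chain([square, double], 3.0, 1.0) == 12.0
```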
@@ -170,35 +163,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
+        # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-            ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)

-        # Run a backward chain for requested backends
+        # Run a backward chain for requested backends
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]

         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
+            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_grad_inputs
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]

         async for part in as_aiter(*output_split):
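As in rpc_forward_stream, this final loop yields one split part per response message. A minimal sketch of that fan-out, with `to_aiter` as an illustrative stand-in for the `as_aiter` used above:

```python
import asyncio

async def to_aiter(*items):
    # Turn a flat list of pre-split parts into an async stream, one item at a time
    for item in items:
        yield item

async def stream_response(parts):
    responses = []
    async for part in to_aiter(*parts):
        responses.append({"tensors": [part]})  # one part per response message
    return responses

print(asyncio.run(stream_response([b"chunk-0", b"chunk-1", b"chunk-2"])))
```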