@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -202,14 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-
-            # Serialize output and respond to client
             return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-                ]
+                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )
 
     async def rpc_forward_stream(
@@ -230,22 +224,34 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert (
-                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-            ), "hidden_states must be a 3d tensor"
 
-            # Serialize the overall output
-            serialized_output = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]
-
             # Split the serialized_output for streaming and respond to client
-            output_split = [
-                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    def _serialize_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize forward outputs using either outputs_schema or custom user-specified schema"""
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+        outputs_schema = requested_backends[-1].outputs_schema
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == 1, "output_compression tuple should have 1 element"
+        else:
+            output_compression = tuple(tensor.compression for tensor in outputs_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
+        ]
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
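For context, a minimal caller-side sketch of the new output_compression hook. The metadata encoding (MSGPackSerializer) and the integer codes (hivemind's runtime_pb2.CompressionType) are assumptions about the surrounding request plumbing, not part of this diff:

# Sketch: ask the server to compress the forward output as float16 instead of the
# schema default. Compression codes travel as plain ints inside the metadata dict.
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

metadata = MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
# rpc_forward returns a single hidden_states tensor, so the tuple must have exactly
# one element; _serialize_outputs above validates this before serializing.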
@@ -265,21 +271,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_input and respond
-            return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-                ]
-            )
+            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -298,28 +290,38 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
 
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_inputs
-            serialized_grad_inputs = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
-            # Split the serialized_grad_inputs for streaming and respond
-            output_split = [
-                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
 
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
-
-    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+    def _serialize_grads(
+        self,
+        grads: Sequence[torch.Tensor],
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        flat_grads_schema = tuple(
+            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
+        )  # TODO generalize
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
+        else:
+            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
+        ]
+
+    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
         """Check that the first request to rpc_inference is valid"""
         uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
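The backward counterpart: _rpc_backward returns a gradient w.r.t. inputs and, when prompts are used, a second one w.r.t. prompts, so a caller-side override must carry one code per gradient. A hypothetical sketch, under the same assumed plumbing as the forward example above:

# Sketch: two gradients -> two compression codes (order matches the grads tuple).
metadata = MSGPackSerializer.dumps(
    dict(
        output_compression=(
            runtime_pb2.CompressionType.FLOAT16,  # grad w.r.t. input hidden states
            runtime_pb2.CompressionType.NONE,  # grad w.r.t. prompts, left uncompressed
        )
    )
)
# Without an override, _serialize_grads falls back to the compression declared in
# args_schema, repeated len(grads) times via flat_grads_schema.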
@@ -360,7 +362,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles
 
-    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
         friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
         friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
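A worked example of the friendly_uids reduction in _log_request, assuming the "<model_name>.<block_index>" UID format implied by the split on ".":

uids = ["bloom.3", "bloom.4", "bloom.5"]
suffixes = [uid.split(".")[-1] for uid in uids if "." in uid]  # ["3", "4", "5"]
indices = [int(s) for s in suffixes if s.isdigit()]  # [3, 4, 5]
assert f"{min(indices)}:{max(indices) + 1}" == "3:6"  # logged as a half-open range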