@@ -16,8 +16,15 @@ from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
 
+# We prioritize short inference requests and make them use a *merged* inference pool,
+# so they are processed without interruptions and extra overheads
+# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
+MAX_SHORT_INFERENCE_TOKENS = 128
+MAX_NF4_SHORT_INFERENCE_TOKENS = 1
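+# (both thresholds are compared against batch_size * length_increment for a single inference step)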
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -127,9 +134,11 @@ async def iterate_rpc_inference(
     active_adapter: Optional[str],
     input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
     cache_handles: Sequence[Sequence[Handle]],
+    *,
     max_length: int,
     prioritizer: TaskPrioritizerBase,
     points: int,
+    quant_type: QuantType,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -138,6 +147,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -154,34 +164,40 @@ async def iterate_rpc_inference(
         if not (len(requested_backends) == len(prompts)):
             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
 
-        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
         if prefix_length + length_increment > max_length:
             raise ValueError(
                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                 f" exceeds pre-allocated maximum {max_length}"
             )
 
+        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
+        can_merge_pools = batch_size * length_increment <= merge_max_tokens
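+        # A step is "short" if the whole batch adds at most merge_max_tokens tokens; such steps are
+        # labeled "short_inference" for the prioritizer and run through the merged pool below.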
         priority = prioritizer.prioritize(
             hidden_states,
             hypo_ids,
             points=point_per_piece,
             requested_uids=requested_uids,
-            type="inference",
-        )
-
-        inference_infos = tuple(
-            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-            for uid, handles in zip(requested_uids, cache_handles)
+            type="short_inference" if can_merge_pools else "inference",
         )
 
-        if hidden_states.numel() == 0:
-            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-        else:
+        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
+        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
+        if hidden_states.numel() > 0:
             assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-            )
+            if can_merge_pools:
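+                # Merged pool: one submit_task call covers all requested blocks, so the short step
+                # runs without interruptions or extra per-block overhead (see note at the top of the file).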
+                inference_infos = tuple(
+                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+                    for uid, handles in zip(requested_uids, cache_handles)
+                )
+                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                )
+            else:
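+                # Unmerged pools: long steps go through each block's own pool one block at a time,
+                # which lets higher-priority (short) requests run in between blocks.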
+                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
+                    (hidden_states,) = await backend.inference_pool.submit_task(
+                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                    )
 
         # serialize and send last layer outputs
         output_tensors = [