@@ -1,8 +1,10 @@
+import contextlib
 from typing import List, Optional
 
 import torch
 from hivemind.utils.logging import get_logger
 
+from petals.client.inference_session import InferenceSession
 from petals.utils.generation_algorithms import (
     BeamSearchAlgorithm,
     DecodingAlgorithm,
@@ -23,9 +25,20 @@ class RemoteGenerationMixin:
         - *multinomial sampling*.
         - *beam-search decoding*
 
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences for remote usage.
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
+    However, it has some differences for remote usage.
     """
 
+    def inference_session(self, **kwargs) -> InferenceSession:
+        """
+        Returns an inference session for the model's RemoteSequential module.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
+
+        return self.transformer.h.inference_session(**kwargs)
+
     @torch.no_grad()
     def generate(
         self,
@@ -43,6 +56,8 @@ class RemoteGenerationMixin:
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         num_return_sequences: Optional[int] = None,
+        *,
+        session: Optional[InferenceSession] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -74,8 +89,6 @@ class RemoteGenerationMixin:
         assert (
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
-        if inputs is not None:
-            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
 
         prefix_length = 0 if inputs is None else inputs.size(1)
         prefix_length += self.config.pre_seq_len
@@ -83,8 +96,6 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
-        batch_size = inputs.size(0)
-
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
@@ -92,9 +103,22 @@ class RemoteGenerationMixin:
         elif max_length is None and max_new_tokens is not None:
             max_length = prefix_length + max_new_tokens
 
-        if inputs is None:
-            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        if num_beams > 1 and session is not None:
+            raise NotImplementedError(
+                "Reusing inference session in .generate() along with beam search is not supported yet"
+            )
+
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+            if session is not None and session.last_token_id is not None:
+                inputs = torch.cat([session.last_token_id, inputs], dim=1)
+        else:
+            if session is not None and session.last_token_id is not None:
+                inputs = session.last_token_id
+            else:
+                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+                inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        batch_size = inputs.size(0)
 
         if decoding_algorithm is None:
             if do_sample:
@@ -109,7 +133,8 @@ class RemoteGenerationMixin:
             if batch_size > 1:
                 # TODO: resolve padding problem
                 logger.warning(
-                    f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
+                    f"You set batch_size {batch_size} within beam search generation. "
+                    f"Be careful, results on sequences with different length may be padded wrong way"
                 )
 
         if num_return_sequences is None:
@@ -127,7 +152,11 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
         )
 
-        with self.transformer.h.inference_session(max_length=max_length) as sess:
+        if session is None:
+            context_manager = self.inference_session(max_length=max_length)
+        else:
+            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
+        with context_manager as session:
             outputs = []
             # Find samples with padded inputs.
             # They will be changed before all of the samples have right length.
@@ -145,7 +174,7 @@ class RemoteGenerationMixin:
                 prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                 embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
@@ -166,6 +195,7 @@ class RemoteGenerationMixin:
                     outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
+                session.last_token_id = last_token_id
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break
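
Usage note for reviewers: a minimal sketch of how the new `session` keyword composes with `inference_session()`. The `model` and `tokenizer` setup below is hypothetical (any Petals model mixing in `RemoteGenerationMixin` would do); only the `inference_session(...)` and `generate(..., session=...)` calls come from this diff.

```python
import torch

# Hypothetical setup, not part of this diff: `model` is a Petals model that
# mixes in RemoteGenerationMixin, `tokenizer` is its matching tokenizer.
prefix = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Open one session sized for the whole exchange, then reuse it across
# generate() calls so servers keep their attention caches between calls.
with model.inference_session(max_length=64) as session:
    outputs = model.generate(prefix, max_new_tokens=8, session=session)
    # A follow-up call may omit inputs: generate() falls back to
    # session.last_token_id saved at the end of the previous call.
    outputs = model.generate(max_new_tokens=8, session=session)
    # Note: num_beams > 1 with a reused session raises NotImplementedError.
```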