@ -1,10 +1,18 @@
from typing import List , Optional
import torch
import torch . nn . functional as F
from hivemind . utils . logging import get_logger
from src . utils . generation_algorithms import DecodingAlgorithm , GreedyAlgorithm , NucleusAlgorithm , TopKAlgorithm
from src . utils . generation_constraints import ABCBloomConstraint , EosConstraint , MaxNewTokensConstraint
from src . utils . generation_algorithms import (
BeamSearchAlgorithm ,
DecodingAlgorithm ,
GreedyAlgorithm ,
NucleusAlgorithm ,
TopKAlgorithm ,
)
from src . utils . generation_constraints import ABCBloomConstraint , EosConstraint
logger = get_logger ( __file__ )
class RemoteGenerationMixin :
@ -13,8 +21,9 @@ class RemoteGenerationMixin:
    The class can be used for :
- * greedy decoding * .
- * multinomial sampling * .
- * beam - search decoding *
This class is similar to transformer ' s [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences .
This class is similar to transformer ' s [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences for remote usage .
"""
@torch.no_grad ( )
@ -25,6 +34,7 @@ class RemoteGenerationMixin:
temperature : float = 1.0 ,
top_k : Optional [ int ] = None ,
top_p : Optional [ float ] = None ,
num_beams : Optional [ int ] = 1 ,
bos_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
pad_token_id : Optional [ int ] = None ,
@ -32,6 +42,7 @@ class RemoteGenerationMixin:
max_new_tokens : Optional [ int ] = None ,
decoding_algorithm : Optional [ DecodingAlgorithm ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
num_return_sequences : Optional [ int ] = None ,
* * model_kwargs ,
) - > torch . LongTensor :
"""
@ -42,6 +53,7 @@ class RemoteGenerationMixin:
: param temperature : The temperature to use for sampling .
: param top_k : The number of results to return .
: param top_p : The cumulative probability of results to return .
: param num_beams : The number of beams to use for beam search .
: param bos_token_id : The id of the beginning of sentence token .
: param eos_token_id : The id of the end of sentence token .
: param pad_token_id : The id of the padding token .
@ -49,6 +61,7 @@ class RemoteGenerationMixin:
: param decoding_algorithm : The decoding algorithm to use .
: param provided_constraints : A list of constraints to use .
: param model_kwargs : Additional arguments to pass to the model .
    : param num_return_sequences : How many hypotheses from the beam will be in the output .
"""
assert (
@ -69,6 +82,8 @@ class RemoteGenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self . config . pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self . config . eos_token_id
batch_size = inputs . size ( 0 )
assert ( max_length is None ) != ( max_new_tokens is None ) , " please set max_length or max_new_tokens (not both) "
if max_length is not None and max_new_tokens is None :
max_new_tokens = max_length - prefix_length
@ -78,24 +93,43 @@ class RemoteGenerationMixin:
if inputs is None :
assert bos_token_id is not None , " You have to provide a bos_token_id if you do not provide inputs "
inputs = torch . tensor ( [ [ bos_token_id ] ] )
inputs = torch . tensor ( [ [ bos_token_id ] ] * num_beams , dtype = torch . long , device = self . device )
if decoding_algorithm is None :
if do_sample :
decoding_algorithm = self . _choose_sample_algorithm ( temperature , top_k , top_p )
elif num_beams is not None and num_beams > 1 :
decoding_algorithm = BeamSearchAlgorithm ( num_beams , batch_size = batch_size )
else :
decoding_algorithm = GreedyAlgorithm ( )
if num_beams > 1 :
inputs = torch . cat ( [ inputs ] * num_beams , dim = 0 )
if batch_size > 1 :
# TODO: resolve padding problem
logger . warning (
f " You set batch_size { batch_size } within beam search generation. Be careful, results on sequences with different length may be padded wrong way "
)
if num_return_sequences is None :
num_return_sequences = 1
assert num_return_sequences < = num_beams , (
f " You want more sequences than the beam has. "
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams} . "
)
constraints = self . _get_constraints (
inputs = inputs ,
eos_token_id = eos_token_id ,
pad_token_id = pad_token_id ,
max_new_tokens = max_new_tokens ,
provided_constraints = provided_constraints ,
)
with self . transformer . h . inference_session ( max_length = max_length ) as sess :
outputs = [ ]
# Find samples with padded inputs.
# They will be changed before all of the samples have right length.
if torch . any ( inputs == pad_token_id ) : # TODO: move to prepare_inputs
outputs + = [ inputs [ : , : inputs . size ( 1 ) - ( inputs == pad_token_id ) . sum ( - 1 ) . max ( ) ] ]
else :
@ -117,19 +151,34 @@ class RemoteGenerationMixin:
for constraint in constraints :
lm_logits = constraint ( last_token_id , lm_logits , hypo_ids )
last_token_id , hypo_ids = decoding_algorithm ( lm_logits )
if seq_idx < inputs . size ( 1 ) : # TODO: why is it not a constraint?
# If some samples were padded, change only these samples
if seq_idx < inputs . size ( 1 ) :
pad_token_mask = inputs [ : , seq_idx : seq_idx + 1 ] == pad_token_id
last_token_id = ( ~ pad_token_mask ) * inputs [
: , seq_idx : seq_idx + 1
] + pad_token_mask * last_token_id
if torch . all ( last_token_id == eos_token_id ) :
break
# TODO: refactor outputs
if num_beams > 1 :
for i in range ( len ( outputs ) , 1 , - 1 ) :
outputs [ i - 1 ] = outputs [ i - 1 ] [ hypo_ids ]
outputs . append ( last_token_id )
seq_idx + = 1
if torch . all ( last_token_id == eos_token_id ) or len ( outputs ) > max_new_tokens :
break
outputs = torch . cat ( outputs , dim = - 1 )
return torch . cat ( outputs , dim = - 1 )
if num_beams > 1 :
pre_return_idx = [
torch . arange ( idx , num_return_sequences * batch_size , batch_size ) for idx in range ( batch_size )
]
return_idx = torch . cat ( pre_return_idx , dim = 0 )
outputs = outputs [ return_idx ]
return outputs
def greedy_search (
self ,
@ -198,13 +247,38 @@ class RemoteGenerationMixin:
def beam_search(
    self,
    input_ids: torch.LongTensor,
    num_beams: int = 1,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    provided_constraints: List[ABCBloomConstraint] = [],
    **model_kwargs,
) -> torch.LongTensor:
    """
    Generates sequences of token ids for models with a language modeling head. Uses beam search.

    :param input_ids: The input tokens to the model.
    :param num_beams: The number of beams to use.
    :param max_length: The maximum length of the sequence to generate.
    :param pad_token_id: The id of the padding token.
    :param eos_token_id: The id of the end of sentence token.
    :param provided_constraints: A list of constraints to use.
    :param model_kwargs: Additional kwargs to pass to the model.
    """
    # One beam-search decoder per call; batch_size is taken from the input batch dimension.
    decoding_algorithm = BeamSearchAlgorithm(
        num_beams=num_beams,
        batch_size=input_ids.size(0),
    )
    # NOTE(review): max_length is forwarded as max_new_tokens — generate() asserts that
    # exactly one of the two is set, so this relabels the length budget; confirm intended.
    return self.generate(
        inputs=input_ids,
        num_beams=num_beams,
        max_new_tokens=max_length,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        decoding_algorithm=decoding_algorithm,
        provided_constraints=provided_constraints,
        **model_kwargs,
    )
def beam_sample (
self ,
@ -246,12 +320,9 @@ class RemoteGenerationMixin:
inputs : Optional [ torch . Tensor ] = None ,
eos_token_id : Optional [ int ] = None ,
pad_token_id : Optional [ int ] = None ,
max_new_tokens : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
) - > List [ ABCBloomConstraint ] :
constraints = [ ]
constraints . extend ( provided_constraints )
if max_new_tokens is not None :
constraints . append ( MaxNewTokensConstraint ( inputs , max_new_tokens , eos_token_id , pad_token_id ) )
constraints . append ( EosConstraint ( inputs , eos_token_id , pad_token_id ) )
return constraints