@ -44,6 +44,7 @@ class RemoteGenerationMixin:
def generate (
self ,
inputs : Optional [ torch . Tensor ] = None ,
* ,
do_sample : Optional [ bool ] = None ,
temperature : float = 1.0 ,
top_k : Optional [ int ] = None ,
@ -57,9 +58,7 @@ class RemoteGenerationMixin:
decoding_algorithm : Optional [ DecodingAlgorithm ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
num_return_sequences : Optional [ int ] = None ,
* ,
session : Optional [ InferenceSession ] = None ,
* * model_kwargs ,
) - > torch . LongTensor :
"""
Generates sequences of token ids for models with a language modeling head .
@ -77,19 +76,9 @@ class RemoteGenerationMixin:
: param max_new_tokens : The maximum number of tokens to generate .
: param decoding_algorithm : The decoding algorithm to use .
: param provided_constraints : A list of constraints to use .
: param model_kwargs : Additional arguments to pass to the model .
: param num_return_sequences : How many hypothesis from the beam will be in output .
"""
assert (
model_kwargs . get ( " logits_processor " , None ) is None
) , " For RemoteGenerationMixin models use BloomConstraints instead of logits_processor "
assert (
model_kwargs . get ( " logits_wrapper " , None ) is None
) , " For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper "
assert (
model_kwargs . get ( " stopping_criteria " , None ) is None
) , " For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria "
prefix_length = 0 if inputs is None else inputs . size ( 1 )
prefix_length + = self . config . pre_seq_len
@ -226,7 +215,6 @@ class RemoteGenerationMixin:
pad_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
* * model_kwargs ,
) - > torch . LongTensor :
"""
Generates sequences of token ids for models with a language modeling head . Uses greedy search .
@ -244,7 +232,6 @@ class RemoteGenerationMixin:
eos_token_id = eos_token_id ,
decoding_algorithm = GreedyAlgorithm ( ) ,
provided_constraints = provided_constraints ,
* * model_kwargs ,
)
def sample (
@ -257,7 +244,6 @@ class RemoteGenerationMixin:
pad_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
* * model_kwargs ,
) - > torch . LongTensor :
"""
Generates sequences of token ids for models with a language modeling head . Uses multinomial sampling .
@ -271,7 +257,6 @@ class RemoteGenerationMixin:
: param : pad_token_id : The id of the padding token .
: param : eos_token_id : The id of the end of sentence token .
: param : provided_constraints : A list of constraints to use .
: param : model_kwargs : Additional kwargs to pass to the model .
"""
return self . generate (
@ -281,7 +266,6 @@ class RemoteGenerationMixin:
eos_token_id = eos_token_id ,
decoding_algorithm = self . _choose_sample_algorithm ( temperature , top_k , top_p ) ,
provided_constraints = provided_constraints ,
* * model_kwargs ,
)
def beam_search (
@ -292,7 +276,6 @@ class RemoteGenerationMixin:
pad_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
* * model_kwargs ,
) - > torch . LongTensor :
"""
Generates sequences of token ids for models with a language modeling head . Uses beam search .
@ -303,7 +286,6 @@ class RemoteGenerationMixin:
: param pad_token_id : The id of the padding token .
: param eos_token_id : The id of the end of sentence token .
: param provided_constraints : A list of constraints to use .
: param : model_kwargs : Additional kwargs to pass to the model .
"""
decoding_algorithm = BeamSearchAlgorithm (
num_beams = num_beams ,
@ -317,7 +299,6 @@ class RemoteGenerationMixin:
eos_token_id = eos_token_id ,
decoding_algorithm = decoding_algorithm ,
provided_constraints = provided_constraints ,
* * model_kwargs ,
)
def beam_sample (
@ -327,7 +308,6 @@ class RemoteGenerationMixin:
pad_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
* * model_kwargs ,
) - > torch . LongTensor :
raise NotImplementedError
@ -338,7 +318,6 @@ class RemoteGenerationMixin:
pad_token_id : Optional [ int ] = None ,
eos_token_id : Optional [ int ] = None ,
provided_constraints : List [ ABCBloomConstraint ] = [ ] ,
* * model_kwargs ,
) - > torch . LongTensor :
raise NotImplementedError