Raise error for unexpected .generate() kwargs (#315)

Before this PR, unexpected kwargs passed to `.generate()` were silently __ignored__, and the code kept working as if the arguments were supported. For example, people often passed `repetition_penalty` and didn't notice that it had no effect. This PR removes the `**model_kwargs` catch-all (and makes the generation options keyword-only), so such calls now raise an error instead.
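As a sketch of the mechanism (toy classes, not the actual Petals code): once the `**model_kwargs` catch-all is gone, Python itself rejects unknown keyword arguments with a `TypeError` at call time.

```python
# Toy illustration (not Petals code) of the effect of dropping **model_kwargs.


class BeforeThisPR:
    def generate(self, inputs=None, *, temperature: float = 1.0, **model_kwargs):
        # Unknown kwargs like repetition_penalty silently land in model_kwargs.
        return f"silently ignored: {sorted(model_kwargs)}"


class AfterThisPR:
    def generate(self, inputs=None, *, temperature: float = 1.0):
        # No **model_kwargs: unknown kwargs fail fast with a TypeError.
        return "ok"


print(BeforeThisPR().generate(repetition_penalty=1.2))  # silently ignored: ['repetition_penalty']

try:
    AfterThisPR().generate(repetition_penalty=1.2)
except TypeError as e:
    print(e)  # generate() got an unexpected keyword argument 'repetition_penalty'
```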
Alexander Borzunov 1 year ago committed by GitHub
parent d9e7bfc949
commit 6eb306a605

@@ -44,6 +44,7 @@ class RemoteGenerationMixin:
def generate(
self,
inputs: Optional[torch.Tensor] = None,
*,
do_sample: Optional[bool] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
@@ -57,9 +58,7 @@ class RemoteGenerationMixin:
decoding_algorithm: Optional[DecodingAlgorithm] = None,
provided_constraints: List[ABCBloomConstraint] = [],
num_return_sequences: Optional[int] = None,
*,
session: Optional[InferenceSession] = None,
**model_kwargs,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head.
@@ -77,19 +76,9 @@ class RemoteGenerationMixin:
:param max_new_tokens: The maximum number of tokens to generate.
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.
:param model_kwargs: Additional arguments to pass to the model.
:param num_return_sequences: How many hypothesis from the beam will be in output.
"""
assert (
model_kwargs.get("logits_processor", None) is None
), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
assert (
model_kwargs.get("logits_wrapper", None) is None
), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
assert (
model_kwargs.get("stopping_criteria", None) is None
), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
prefix_length = 0 if inputs is None else inputs.size(1)
prefix_length += self.config.pre_seq_len
@@ -226,7 +215,6 @@ class RemoteGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
@@ -244,7 +232,6 @@ class RemoteGenerationMixin:
eos_token_id=eos_token_id,
decoding_algorithm=GreedyAlgorithm(),
provided_constraints=provided_constraints,
**model_kwargs,
)
def sample(
@@ -257,7 +244,6 @@ class RemoteGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
@@ -271,7 +257,6 @@ class RemoteGenerationMixin:
:param: pad_token_id: The id of the padding token.
:param: eos_token_id: The id of the end of sentence token.
:param: provided_constraints: A list of constraints to use.
:param: model_kwargs: Additional kwargs to pass to the model.
"""
return self.generate(
@@ -281,7 +266,6 @@ class RemoteGenerationMixin:
eos_token_id=eos_token_id,
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
provided_constraints=provided_constraints,
**model_kwargs,
)
def beam_search(
@@ -292,7 +276,6 @@ class RemoteGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses beam search.
@@ -303,7 +286,6 @@ class RemoteGenerationMixin:
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
:param: model_kwargs: Additional kwargs to pass to the model.
"""
decoding_algorithm = BeamSearchAlgorithm(
num_beams=num_beams,
@@ -317,7 +299,6 @@ class RemoteGenerationMixin:
eos_token_id=eos_token_id,
decoding_algorithm=decoding_algorithm,
provided_constraints=provided_constraints,
**model_kwargs,
)
def beam_sample(
@@ -327,7 +308,6 @@ class RemoteGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
raise NotImplementedError
@@ -338,7 +318,6 @@ class RemoteGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
raise NotImplementedError
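A usage note on the other part of this change: moving the bare `*` up so that it directly follows `inputs` makes every generation option keyword-only. As a hedged sketch of the intended call style (toy class, not Petals code), keyword arguments keep working, while passing options positionally now raises a `TypeError`.

```python
# Toy sketch (not Petals code): a bare `*` right after `inputs` makes all
# generation options keyword-only, mirroring the new signature.


class KeywordOnlyGenerate:
    def generate(self, inputs=None, *, do_sample=None, temperature: float = 1.0, max_new_tokens=None):
        return {"do_sample": do_sample, "temperature": temperature, "max_new_tokens": max_new_tokens}


m = KeywordOnlyGenerate()
print(m.generate([1, 2, 3], temperature=0.8, max_new_tokens=16))  # OK: keyword arguments

try:
    m.generate([1, 2, 3], True, 0.8)  # positional options are no longer accepted
except TypeError as e:
    print(e)  # generate() takes from 1 to 2 positional arguments but 4 were given
```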
