Add pre_seq_len to prefix size

prompt-inference
Artem Chumachenko 2 years ago
parent a5f84fd6fb
commit 3deb385865

@ -63,6 +63,7 @@ class RemoteGenerationMixin:
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
prefix_length = 0 if inputs is None else inputs.size(1)
prefix_length += self.config.pre_seq_len
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id

Loading…
Cancel
Save