diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index 0fc9c00..e4875cc 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -63,6 +63,7 @@ class RemoteGenerationMixin: if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" prefix_length = 0 if inputs is None else inputs.size(1) + prefix_length += self.config.pre_seq_len bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id