From 3deb3858654162b44f02c69ac9a1d0c799a4a9ae Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Wed, 31 Aug 2022 11:05:42 +0300 Subject: [PATCH] Add pre_seq_len to prefix length --- src/client/remote_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index 0fc9c00..e4875cc 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -63,6 +63,7 @@ class RemoteGenerationMixin: if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" prefix_length = 0 if inputs is None else inputs.size(1) + prefix_length += self.config.pre_seq_len bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id