Account for pre_seq_len in new session's max_length

pull/501/head
Aleksandr Borzunov 9 months ago
parent 8cb6b37e2f
commit db46bf3ac1

@@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
                 max_new_tokens is None
             ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
+            session_max_length = self.transformer.config.pre_seq_len
             if max_length is not None:
-                session_max_length = max_length
+                session_max_length += max_length
             else:
-                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
             with context_manager as session:
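
The net effect: the reserved session length now also covers the `pre_seq_len` trained prompt tokens (used in prompt tuning), not just the user-visible sequence. A minimal standalone sketch of the resulting arithmetic, assuming a config object with a `pre_seq_len` attribute; the helper name `compute_session_max_length` and the exact assert condition are illustrative assumptions, only the length computation is taken from the diff:

    def compute_session_max_length(config, inputs=None, max_length=None, max_new_tokens=None):
        # Assumed precondition, mirroring the assert message in the diff:
        # exactly one of max_length / max_new_tokens should be set.
        assert (
            max_length is None or max_new_tokens is None
        ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
        # Start from pre_seq_len so the server-side attention cache reserves
        # slots for the trained prompt tokens as well.
        session_max_length = config.pre_seq_len
        if max_length is not None:
            session_max_length += max_length
        else:
            session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
        return session_max_length

For example, with pre_seq_len = 16, a 10-token input, and max_new_tokens = 100, the session now reserves 126 cache positions instead of the 110 it reserved before this commit.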
