|
|
|
@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
|
|
|
|
|
max_new_tokens is None
|
|
|
|
|
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
|
|
|
|
|
|
|
|
|
|
session_max_length = self.transformer.config.pre_seq_len
|
|
|
|
|
if max_length is not None:
|
|
|
|
|
session_max_length = max_length
|
|
|
|
|
session_max_length += max_length
|
|
|
|
|
else:
|
|
|
|
|
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
|
|
|
|
|
session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
|
|
|
|
|
context_manager = self.inference_session(max_length=session_max_length)
|
|
|
|
|
|
|
|
|
|
with context_manager as session:
|
|
|
|
|