@@ -56,7 +56,7 @@ class RemoteGenerationMixin:
             context_manager = self.use_session(session)
         elif self.active_session is not None:
             # If there's an active session, don't do anything
-            context_manager = contextlib.nullcontext()
+            context_manager = contextlib.nullcontext(self.active_session)
         else:
             # If there's no active session, create a new one
 
@@ -72,8 +72,24 @@ class RemoteGenerationMixin:
             session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
 
-        with context_manager:
-            return super().generate(inputs, *args, **kwargs)
+        with context_manager as session:
+            # Prepend the last tokens from the previous .generate() call
+            if session.last_token_id is not None:
+                assert session.last_token_id.shape[1] == 1, f"{session.last_token_id.shape=} is invalid"
+                if inputs is not None:
+                    inputs = torch.cat([session.last_token_id, inputs], dim=1)
+                else:
+                    inputs = session.last_token_id
+
+            result = super().generate(inputs, *args, **kwargs)
+
+            # Crop the last tokens from the previous call
+            if session.last_token_id is not None:
+                result = result[:, 1:]
+            # Save the last tokens from this call
+            session.last_token_id = result[:, -1:]
+
+        return result
 
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
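For context, here is a minimal usage sketch (not part of the patch) of the cross-call behavior the second hunk enables: consecutive `.generate()` calls within one inference session continue from each other, with the previous call's last token prepended internally and cropped from the result. It assumes `model` is an instance of a class mixing in `RemoteGenerationMixin` whose `.generate()` accepts a `session` keyword argument (implied by the `self.use_session(session)` branch above), and `tokenizer` is any compatible tokenizer; both names are hypothetical.

```python
import torch

inputs = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]

with model.inference_session(max_length=64) as session:
    # First call: session.last_token_id is still None, so nothing is
    # prepended or cropped; the call stores result[:, -1:] in
    # session.last_token_id before returning.
    first = model.generate(inputs, max_new_tokens=8, session=session)

    # Second call: session.last_token_id is prepended as the inputs, so
    # generation continues from the previous output; the borrowed token
    # is then cropped, leaving only the newly generated tokens.
    second = model.generate(max_new_tokens=8, session=session)

# The two results concatenate into one continuous token sequence.
print(tokenizer.decode(torch.cat([first, second], dim=1)[0]))
```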