Fix multi-call generate

pull/464/head
Aleksandr Borzunov 10 months ago
parent c066ddf06e
commit 235e29f47c

@ -56,7 +56,7 @@ class RemoteGenerationMixin:
context_manager = self.use_session(session)
elif self.active_session is not None:
# If there's an active session, don't do anything
context_manager = contextlib.nullcontext()
context_manager = contextlib.nullcontext(self.active_session)
else:
# If there's no active session, create a new one
@ -72,8 +72,24 @@ class RemoteGenerationMixin:
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
context_manager = self.inference_session(max_length=session_max_length)
with context_manager:
return super().generate(inputs, *args, **kwargs)
with context_manager as session:
# Prepend the last tokens from the previous .generate() call
if session.last_token_id is not None:
assert session.last_token_id.shape[1] == 1, f"{session.last_token_id.shape=} is invalid"
if inputs is not None:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
else:
inputs = session.last_token_id
result = super().generate(inputs, *args, **kwargs)
# Crop the last tokens from the previous call
if session.last_token_id is not None:
result = result[:, 1:]
# Save the last tokens from this call
session.last_token_id = result[:, -1:]
return result
@staticmethod
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:

Loading…
Cancel
Save