Fix output shape when resuming generation (#211)

Before this PR, `model.generate()` returned one excess token when resuming generation with an existing (the last token of the previous session, `session.last_token_id`). This is an unexpected behavior not convenient for the downstream apps, so this PR changes it until it's too late.
pull/212/head
Alexander Borzunov 1 year ago committed by GitHub
parent cc5e5d32c0
commit 6ba63c6cc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,7 +42,7 @@ install_requires =
humanfriendly
async-timeout>=4.0.2
cpufeature>=0.2.0
packaging>=23.0
packaging>=20.9
[options.extras_require]
dev =

@ -104,17 +104,18 @@ class RemoteGenerationMixin:
elif max_length is None and max_new_tokens is not None:
max_length = prefix_length + max_new_tokens
if num_beams > 1 and session is not None:
resuming_session = session is not None and session.last_token_id is not None
if num_beams > 1 and resuming_session:
raise NotImplementedError(
"Reusing inference session in .generate() along with beam search is not supported yet"
"Resuming inference session in .generate() along with beam search is not supported yet"
)
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
if session is not None and session.last_token_id is not None:
if resuming_session:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
else:
if session is not None and session.last_token_id is not None:
if resuming_session:
inputs = session.last_token_id
else:
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@ -207,6 +208,8 @@ class RemoteGenerationMixin:
outputs = torch.cat(outputs, dim=-1)
if resuming_session:
outputs = outputs[:, 1:]
if num_beams > 1:
pre_return_idx = [
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)

@ -123,6 +123,8 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise ValueError("speedtest has returned network_rps == 0")
logger.info(
f"Network throughput: "

Loading…
Cancel
Save