diff --git a/setup.cfg b/setup.cfg index 73bb117..a05ae6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ install_requires = humanfriendly async-timeout>=4.0.2 cpufeature>=0.2.0 - packaging>=23.0 + packaging>=20.9 [options.extras_require] dev = diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 4182ca8..af4166d 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -104,17 +104,18 @@ class RemoteGenerationMixin: elif max_length is None and max_new_tokens is not None: max_length = prefix_length + max_new_tokens - if num_beams > 1 and session is not None: + resuming_session = session is not None and session.last_token_id is not None + if num_beams > 1 and resuming_session: raise NotImplementedError( - "Reusing inference session in .generate() along with beam search is not supported yet" + "Resuming inference session in .generate() along with beam search is not supported yet" ) if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" - if session is not None and session.last_token_id is not None: + if resuming_session: inputs = torch.cat([session.last_token_id, inputs], dim=1) else: - if session is not None and session.last_token_id is not None: + if resuming_session: inputs = session.last_token_id else: assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" @@ -207,6 +208,8 @@ class RemoteGenerationMixin: outputs = torch.cat(outputs, dim=-1) + if resuming_session: + outputs = outputs[:, 1:] if num_beams > 1: pre_return_idx = [ torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index f491a10..8b6dc9c 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -123,6 +123,8 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request + if network_rps == 0: + raise ValueError("speedtest has returned network_rps == 0") logger.info( f"Network throughput: "