From 8c50f65cf237f68ff66061611f6e46447a3981a4 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Sat, 26 Nov 2022 23:48:12 +0000
Subject: [PATCH] black

---
 src/client/inference_session.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/client/inference_session.py b/src/client/inference_session.py
index 9359694..da45fb7 100644
--- a/src/client/inference_session.py
+++ b/src/client/inference_session.py
@@ -258,10 +258,13 @@ class InferenceSession:
                         self._chosen_spans[server_idx : server_idx + 1] = updated_spans
                         self._server_sessions[server_idx : server_idx + 1] = updated_sessions
                         recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
-                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
-                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
                             f"{len(self._server_inputs)} inputs"
+                        )
 
                     session = self._server_sessions[server_idx]
                     span = self._chosen_spans[server_idx]
@@ -272,9 +275,10 @@ class InferenceSession:
                     self._server_inputs[server_idx] = torch.cat(
                         [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                     )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
-                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
                         f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
                     if not session.stepped:
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
                     else:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert inputs.shape == outputs.shape, \
-                        f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
                     inputs = outputs
                     server_idx += 1