fault-tolerant-inference
Aleksandr Borzunov 1 year ago
parent 0ef1d15c45
commit 8c50f65cf2

@@ -258,10 +258,13 @@ class InferenceSession:
self._chosen_spans[server_idx : server_idx + 1] = updated_spans
self._server_sessions[server_idx : server_idx + 1] = updated_sessions
recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
len(updated_spans) - 1
)
assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
f"{len(self._server_inputs)} inputs"
)
session = self._server_sessions[server_idx]
span = self._chosen_spans[server_idx]
@@ -272,9 +275,10 @@ class InferenceSession:
self._server_inputs[server_idx] = torch.cat(
[self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
)
assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
f"position={self._position} n_input_tokens={n_input_tokens}"
)
if not session.stepped:
inputs = self._server_inputs[server_idx] # Pass full inputs including prefix
@@ -282,8 +286,9 @@ class InferenceSession:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
assert inputs.shape == outputs.shape, \
f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
assert (
inputs.shape == outputs.shape
), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
inputs = outputs
server_idx += 1

Loading…
Cancel
Save