Fix max_length

fault-tolerant-inference
Aleksandr Borzunov 2 years ago
parent 2fafbaa119
commit 01cffeba5d

@ -162,7 +162,7 @@ class InferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
self._sequence_manager = sequence_manager
self._p2p = p2p
self._closed = False
@ -170,6 +170,7 @@ class InferenceSession:
self._server_sessions = []
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
self._position = 0
self._max_length = max_length
self._metadata = metadata
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
@ -183,6 +184,7 @@ class InferenceSession:
span_uids,
rpc_info=self._sequence_manager.rpc_info,
timeout=self._sequence_manager.timeout,
max_length=self._max_length,
**self._metadata,
)
)
@ -210,6 +212,10 @@ class InferenceSession:
else:
assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
n_input_tokens = inputs.shape[1]
if self._position + n_input_tokens > self._max_length:
raise ValueError(
f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
)
server_idx = 0
block_idx = 0

@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
# test that max length is respected
with pytest.raises(P2PHandlerError) as exc_info:
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)

Loading…
Cancel
Save