|
|
|
@@ -162,7 +162,7 @@ class InferenceSession:
|
|
|
|
|
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
|
|
|
|
|
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
|
|
|
|
|
self._sequence_manager = sequence_manager
|
|
|
|
|
self._p2p = p2p
|
|
|
|
|
self._closed = False
|
|
|
|
@@ -170,6 +170,7 @@ class InferenceSession:
|
|
|
|
|
self._server_sessions = []
|
|
|
|
|
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
|
|
|
|
|
self._position = 0
|
|
|
|
|
self._max_length = max_length
|
|
|
|
|
self._metadata = metadata
|
|
|
|
|
|
|
|
|
|
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
|
|
|
|
@@ -183,6 +184,7 @@ class InferenceSession:
|
|
|
|
|
span_uids,
|
|
|
|
|
rpc_info=self._sequence_manager.rpc_info,
|
|
|
|
|
timeout=self._sequence_manager.timeout,
|
|
|
|
|
max_length=self._max_length,
|
|
|
|
|
**self._metadata,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
@@ -210,6 +212,10 @@ class InferenceSession:
|
|
|
|
|
else:
|
|
|
|
|
assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
|
|
|
|
|
n_input_tokens = inputs.shape[1]
|
|
|
|
|
if self._position + n_input_tokens > self._max_length:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
server_idx = 0
|
|
|
|
|
block_idx = 0
|
|
|
|
|