|
|
@ -25,7 +25,7 @@ def test_sequence_manager_basics(mode: str):
|
|
|
|
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
|
|
|
|
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
|
|
|
|
sequential = RemoteSequential(
|
|
|
|
sequential = RemoteSequential(
|
|
|
|
config,
|
|
|
|
config,
|
|
|
|
sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
|
|
|
|
sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
sequence = sequential.sequence_manager.make_sequence(mode=mode)
|
|
|
|
sequence = sequential.sequence_manager.make_sequence(mode=mode)
|
|
|
@ -43,7 +43,7 @@ def test_sequence_manager_basics(mode: str):
|
|
|
|
assert shutdown_evt.is_set()
|
|
|
|
assert shutdown_evt.is_set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSequenceManager(RemoteSequenceManager):
|
|
|
|
class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
|
|
|
|
"""A sequence manager that signals if it was shut down"""
|
|
|
|
"""A sequence manager that signals if it was shut down"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
|
|
|
|
def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
|
|
|
|