Expose request_timeout to DistributedBloomConfig (#105)

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Branch: pull/106/head
Author: Artem Chumachenko, committed by GitHub
Parent: 9faf08b898
Commit: 7d859a947b

@@ -184,7 +184,7 @@ class InferenceSession:
                     stub,
                     span_uids,
                     rpc_info=self._sequence_manager.rpc_info,
-                    timeout=self._sequence_manager.timeout,
+                    timeout=self._sequence_manager.request_timeout,
                     max_length=self._max_length,
                     **self._metadata,
                 )

@@ -36,6 +36,7 @@ class DistributedBloomConfig(BloomConfig):
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for the LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning
     tuning_mode: Optional[str] = None  # one of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    request_timeout: int = 20  # a number of seconds to wait for a result from each node


 original_register_parameter = nn.Module.register_parameter
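
Note: with this field in place, a client can control the per-node deadline through the config object alone. A minimal sketch of the intended usage (the checkpoint name is illustrative and the import path is an assumption, not part of this commit):

    from src.client import DistributedBloomConfig, DistributedBloomModel  # import path is an assumption

    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")  # checkpoint name is illustrative
    config.request_timeout = 60  # wait up to 60 s for each remote node instead of the default 20 s
    model = DistributedBloomModel(config)  # assumes config.initial_peers is set; the value reaches RemoteSequential / RemoteSequenceManager via the hunks below
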
@@ -84,7 +85,7 @@ class DistributedBloomModel(BloomModel):
             else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(config, dht, config.dht_prefix)
+        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)

@@ -30,6 +30,7 @@ class RemoteSequential(nn.Module):
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
+        request_timeout: int = 20,
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@
         block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, request_timeout=request_timeout)
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 20,
+        request_timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
@@ -41,7 +41,7 @@
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
         self.max_retries = max_retries
-        self.timeout, self.min_backoff = timeout, min_backoff
+        self.request_timeout, self.min_backoff = request_timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.policy = NoSpendingPolicy()
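
Note: the renamed keyword also applies when a sequence manager is created directly, e.g. to share one manager between several slices. A sketch, assuming dht, block_uids, and p2p already exist as in the hunks above:

    manager = RemoteSequenceManager(dht, block_uids, p2p, request_timeout=config.request_timeout)
    sequence = RemoteSequential(config, dht, config.dht_prefix, p2p=p2p, sequence_manager=manager)

As the RemoteSequential hunk above shows, the request_timeout argument is only used when no sequence manager is supplied; a reused manager keeps the value it was created with.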

@@ -77,7 +77,7 @@ async def sequential_forward(
                 stub,
                 sequence_manager.rpc_info,
                 *inputs_and_prompts,
-                timeout=sequence_manager.timeout,
+                timeout=sequence_manager.request_timeout,
                 metadata=metadata,
             )
@@ -161,7 +161,7 @@ async def sequential_backward(
                 inputs,
                 grad_outputs,
                 prompts[span.start : span.end],
-                timeout=sequence_manager.timeout,
+                timeout=sequence_manager.request_timeout,
                 metadata=metadata,
             )
             grad_outputs = [grad_outputs]
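
Note: both sequential_forward and sequential_backward now read the same sequence_manager.request_timeout, so one config-driven value bounds every per-node RPC during training, and the InferenceSession hunk at the top does the same for inference. The deadline can also be loosened at runtime through the attribute path established by these hunks (a sketch, assuming model was built as in the earlier note):

    model.h.sequence_manager.request_timeout = 90  # e.g. allow slower nodes when running a large batch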
