Fix dtype- and device-related client issues (#98)

This PR:

1. Makes the client's inference/forward/backward calls remember the dtype and device of the source tensors, then cast/move the outputs back to the same dtype and device (a minimal sketch of this pattern follows the list). This way:
    - Users don't need to change the code that launches `RemoteSequential` to run it on a different device.
    - `model.generate()` now supports both CPU and GPU.

2. Sets `low_cpu_mem_usage=True` by default and raises the client's request timeout to 20 sec.

3. Removes excess casts to float32 left in Dmitry's code.

4. (minor) Improves error messages.
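
For illustration, here is a minimal sketch of the device/dtype bookkeeping that the client code below applies in `InferenceSession.step`, `sequential_forward`, and `sequential_backward`. The `_call_remote` helper and the function name are hypothetical stand-ins, not part of this diff:

```python
import torch


def _call_remote(inputs: torch.Tensor) -> torch.Tensor:
    # Stand-in for the actual remote call: servers receive CPU tensors
    # and may return a different dtype (e.g., float32).
    return inputs.float()


def forward_with_restored_dtype_device(inputs: torch.Tensor) -> torch.Tensor:
    # Remember where the caller's tensor lives and which dtype it uses
    inputs_device = inputs.device
    inputs_dtype = inputs.dtype

    # Tensors are moved to CPU before being serialized for the remote call
    outputs = _call_remote(inputs.cpu())

    # Cast/move the result back so the caller sees its original device/dtype
    return outputs.to(device=inputs_device, dtype=inputs_dtype)


# Works the same whether `x` is fp16 on GPU or fp32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 4, 8, dtype=torch.float16, device=device)
y = forward_with_restored_dtype_device(x)
assert y.device == x.device and y.dtype == x.dtype
```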
Branch: fix-ptune
Alexander Borzunov committed ab41223b17 (parent c6e1b5a8e5)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """
 
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -436,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import itertools
+import logging
 import time
 from typing import AsyncIterator, List, Optional
@@ -18,7 +19,6 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -218,6 +218,11 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
 
+        inputs_device = inputs.device
+        inputs_dtype = inputs.dtype
+        inputs = inputs.cpu()
+        prompts = prompts.cpu()
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
@@ -300,11 +305,14 @@ class InferenceSession:
                     f"Caught exception when running inference from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 time.sleep(delay)
 
         self._position += n_input_tokens
-        return inputs
+
+        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        return outputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""

@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:

@@ -31,7 +31,6 @@ class RemoteSequential(nn.Module):
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
-        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         super().__init__()
         self.config = config
         self.dht = dht

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 5,
+        timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"

@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
+import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple
@@ -36,6 +37,11 @@ async def sequential_forward(
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
+    inputs_device = inputs.device
+    inputs_dtype = inputs.dtype
+    inputs = inputs.cpu()
+    prompts = prompts.cpu()
+
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
     assert is_dummy(prompts) or len(prompts) == len(
@@ -86,9 +92,12 @@ async def sequential_forward(
                 f"Caught exception when running forward from block {block_idx} "
                 f"(retry in {delay:.0f} sec): {repr(e)}"
             )
-            logger.debug("See detailed traceback below:", exc_info=True)
+            traceback_level = logging.DEBUG if str(e) else logging.WARNING
+            logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
             await asyncio.sleep(delay)
 
+    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
     return outputs, intermediate_inputs, done_sequences
@@ -98,13 +107,22 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
     assert len(intermediate_inputs) == len(forward_sequences)
 
+    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+    prompts_device = prompts.device
+    prompts_dtype = prompts.dtype
+
+    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    prompts = prompts.cpu()
+
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         inputs = intermediate_inputs.pop()
@@ -146,12 +164,18 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
 
+    if grad_outputs_dtype is not None:
+        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+    if grad_prompts is not None:
+        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
+
     return grad_outputs, grad_prompts
