Fix dtype- and device-related client issues (#98)

This PR:

1. Makes inference/forward/backward calls on the client remember the dtype and device of the source tensors, then moves/casts the outputs back to the same dtype/device (see the sketch after this list). This way:
    - Users don't need to change the code launching `RemoteSequential` to run it on a different device.
    - `model.generate()` now supports both CPU and GPU.

2. Sets `low_cpu_mem_usage=True` by default and increases the client's request timeout from 5 to 20 sec.

3. Removes excess casts to float32 left in Dmitry's code.

4. (minor) Improves error messages.
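
A minimal sketch of the remember-and-restore pattern behind item 1 (the helper `run_on_servers_like` and its standalone-function form are illustrative only, not code from this PR): the client records the caller's device and dtype, ships CPU tensors over the wire, and casts the result back before returning it.

```python
import torch

def run_on_servers_like(source: torch.Tensor, remote_call) -> torch.Tensor:
    """Hypothetical helper: run `remote_call` on a CPU copy of `source`,
    then move/cast its output back to the caller's device and dtype."""
    src_device, src_dtype = source.device, source.dtype  # remember where the input lives
    outputs = remote_call(source.cpu())                  # tensors are exchanged with servers on CPU
    return outputs.to(device=src_device, dtype=src_dtype)

# A CUDA/fp16 caller gets CUDA/fp16 outputs back even though the remote hop used CPU tensors:
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 4, 8, dtype=torch.float16, device=device)
y = run_on_servers_like(x, remote_call=lambda t: t * 2)  # stand-in for the real remote forward
assert y.device == x.device and y.dtype == x.dtype
```

In the PR itself, the same pattern is applied inside `InferenceSession.step`, `sequential_forward`, and `sequential_backward`, as the diffs below show.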
Branch fix-ptune · Alexander Borzunov, committed by GitHub · commit ab41223b17 (parent c6e1b5a8e5)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
"""
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
@add_start_docstrings(
"The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
BLOOM_START_DOCSTRING,
)
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
def __init__(self, config):
super().__init__(config)
assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
""",
BLOOM_START_DOCSTRING,
)
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
@add_start_docstrings(
"""
The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
""",
BLOOM_START_DOCSTRING,
)
@@ -436,7 +449,7 @@ class LMHead(nn.Module):
else:
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = hidden_states.to(word_embeddings.dtype)
-lm_logits = F.linear(hidden_states, word_embeddings).float()
+lm_logits = F.linear(hidden_states, word_embeddings)
return lm_logits
def chunked_forward(self, hidden_states):
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
""",
BLOOM_START_DOCSTRING,
)
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import itertools
+import logging
import time
from typing import AsyncIterator, List, Optional
@@ -18,7 +19,6 @@ from hivemind import (
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import aiter_with_timeout
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -218,6 +218,11 @@ class InferenceSession:
else:
assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+inputs_device = inputs.device
+inputs_dtype = inputs.dtype
+inputs = inputs.cpu()
+prompts = prompts.cpu()
n_input_tokens = inputs.shape[1]
if self._position + n_input_tokens > self._max_length:
raise ValueError(
@@ -300,11 +305,14 @@ class InferenceSession:
f"Caught exception when running inference from block {block_idx} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
-logger.debug("See detailed traceback below:", exc_info=True)
+traceback_level = logging.DEBUG if str(e) else logging.WARNING
+logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
time.sleep(delay)
self._position += n_input_tokens
-return inputs
+outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+return outputs
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""

@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:

@@ -31,7 +31,6 @@ class RemoteSequential(nn.Module):
p2p: Optional[P2P] = None,
sequence_manager: Optional[RemoteSequenceManager] = None,
):
-logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
super().__init__()
self.config = config
self.dht = dht

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
block_uids: Sequence[ModuleUID],
p2p: P2P,
max_retries: int = 3,
-timeout: float = 5,
+timeout: float = 20,
min_backoff: float = 1,
):
assert len(block_uids) > 0, "Sequences must contain at least one block"

@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
"""
import asyncio
import itertools
+import logging
from collections import deque
from typing import List, Optional, Sequence, Tuple
@@ -36,6 +37,11 @@ async def sequential_forward(
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+inputs_device = inputs.device
+inputs_dtype = inputs.dtype
+inputs = inputs.cpu()
+prompts = prompts.cpu()
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert is_dummy(prompts) or len(prompts) == len(
@@ -86,9 +92,12 @@ async def sequential_forward(
f"Caught exception when running forward from block {block_idx} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
-logger.debug("See detailed traceback below:", exc_info=True)
+traceback_level = logging.DEBUG if str(e) else logging.WARNING
+logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)
+outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
return outputs, intermediate_inputs, done_sequences
@@ -98,13 +107,22 @@ async def sequential_backward(
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
"""
assert len(intermediate_inputs) == len(forward_sequences)
+grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+prompts_device = prompts.device
+prompts_dtype = prompts.dtype
+grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+prompts = prompts.cpu()
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
inputs = intermediate_inputs.pop()
@@ -146,12 +164,18 @@ async def sequential_backward(
f"Caught exception when running backward between blocks {span.start}-{span.end} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
-logger.debug("See detailed traceback below:", exc_info=True)
+traceback_level = logging.DEBUG if str(e) else logging.WARNING
+logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)
# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+if grad_outputs_dtype is not None:
+    grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+if grad_prompts is not None:
+    grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
return grad_outputs, grad_prompts
