From ab41223b17c17dd1035a42318b03d4b92decd063 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Tue, 29 Nov 2022 16:08:02 +0400
Subject: [PATCH] Fix dtype- and device-related client issues (#98)

This PR:

1. Makes inference/forward/backward calls on the client remember the dtype and device
   of the source tensors, then move/cast the outputs to the same dtype/device. This way:
   - Users don't need to change the code launching `RemoteSequential` to make it run
     on a different device.
   - `model.generate()` also starts to support both CPU and GPU.
2. Sets the default `low_cpu_mem_usage=True` and increases the client's request timeout to 20 sec.
3. Removes excess casts to float32 left in Dmitry's code.
4. (minor) Improves error messages.
---
 src/bloom/model.py                | 27 ++++++++++++++++++++-------
 src/client/inference_session.py   | 14 +++++++++++---
 src/client/remote_model.py        |  2 +-
 src/client/remote_sequential.py   |  1 -
 src/client/sequence_manager.py    |  2 +-
 src/client/sequential_autograd.py | 30 +++++++++++++++++++++++++++---
 6 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/src/bloom/model.py b/src/bloom/model.py
index a5c7d9e..d63695e 100644
--- a/src/bloom/model.py
+++ b/src/bloom/model.py
@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """


+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU. 
""", BLOOM_START_DOCSTRING, ) @@ -436,7 +449,7 @@ class LMHead(nn.Module): else: # Switch dtype in case word_embeddings are fp16/bf16 hidden_states = hidden_states.to(word_embeddings.dtype) - lm_logits = F.linear(hidden_states, word_embeddings).float() + lm_logits = F.linear(hidden_states, word_embeddings) return lm_logits def chunked_forward(self, hidden_states): @@ -470,7 +483,7 @@ class LMHead(nn.Module): """, BLOOM_START_DOCSTRING, ) -class BloomForSequenceClassification(BloomPreTrainedModel): +class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults): _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] def __init__(self, config): diff --git a/src/client/inference_session.py b/src/client/inference_session.py index da45fb7..9d98333 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import itertools +import logging import time from typing import AsyncIterator, List, Optional @@ -18,7 +19,6 @@ from hivemind import ( from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import StubBase from hivemind.proto import runtime_pb2 -from hivemind.utils.asyncio import aiter_with_timeout from src.client.sequence_manager import RemoteSequenceManager from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo @@ -218,6 +218,11 @@ class InferenceSession: else: assert prompts.ndim == 4 and prompts.shape[0] == n_blocks + inputs_device = inputs.device + inputs_dtype = inputs.dtype + inputs = inputs.cpu() + prompts = prompts.cpu() + n_input_tokens = inputs.shape[1] if self._position + n_input_tokens > self._max_length: raise ValueError( @@ -300,11 +305,14 @@ class InferenceSession: f"Caught exception when running inference from block {block_idx} " f"(retry in {delay:.0f} sec): {repr(e)}" ) - logger.debug("See detailed traceback below:", exc_info=True) + traceback_level = logging.DEBUG if str(e) else logging.WARNING + logger.log(traceback_level, "See detailed traceback below:", exc_info=True) time.sleep(delay) self._position += n_input_tokens - return inputs + + outputs = inputs.to(device=inputs_device, dtype=inputs_dtype) + return outputs def close(self, *exc_details): """Finish a given inference session, close the underlying connection""" diff --git a/src/client/remote_model.py b/src/client/remote_model.py index b846d6e..21b41dd 100644 --- a/src/client/remote_model.py +++ b/src/client/remote_model.py @@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel): prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) - hidden_states = self.word_embeddings_layernorm(inputs_embeds.float()) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) if self.config.tuning_mode and "ptune" in self.config.tuning_mode: diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index fb62249..575306e 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -31,7 +31,6 @@ class RemoteSequential(nn.Module): p2p: Optional[P2P] = None, sequence_manager: Optional[RemoteSequenceManager] = None, ): - logger.warning(f"{self.__class__.__name__} is in active development; expect adventures") super().__init__() self.config = config self.dht = dht diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 
index de66d84..800bd4a 100644
--- a/src/client/sequence_manager.py
+++ b/src/client/sequence_manager.py
@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 5,
+        timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py
index 364a6b5..2e9f62c 100644
--- a/src/client/sequential_autograd.py
+++ b/src/client/sequential_autograd.py
@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
+import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple

@@ -36,6 +37,11 @@ async def sequential_forward(

     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"

+    inputs_device = inputs.device
+    inputs_dtype = inputs.dtype
+    inputs = inputs.cpu()
+    prompts = prompts.cpu()
+
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
     assert is_dummy(prompts) or len(prompts) == len(
@@ -86,9 +92,12 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)

+    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
     return outputs, intermediate_inputs, done_sequences


@@ -98,13 +107,22 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
""" assert len(intermediate_inputs) == len(forward_sequences) + grad_outputs_device = grad_outputs[0].device if grad_outputs else None + grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None + prompts_device = prompts.device + prompts_dtype = prompts.dtype + + grad_outputs = [tensor.cpu() for tensor in grad_outputs] + intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs] + prompts = prompts.cpu() + grad_prompts_reversed = [] while len(forward_sequences) > 0 and len(intermediate_inputs) > 0: inputs = intermediate_inputs.pop() @@ -146,12 +164,18 @@ async def sequential_backward( f"Caught exception when running backward between blocks {span.start}-{span.end} " f"(retry in {delay:.0f} sec): {repr(e)}" ) - logger.debug("See detailed traceback below:", exc_info=True) + traceback_level = logging.DEBUG if str(e) else logging.WARNING + logger.log(traceback_level, "See detailed traceback below:", exc_info=True) await asyncio.sleep(delay) # For now, we do not support mixed dummy and grad prompts # Concat in num_layer dimension grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None + + if grad_outputs_dtype is not None: + grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs] + if grad_prompts is not None: + grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype) return grad_outputs, grad_prompts