From c0e0e1319dfae0c307cdfc8cb86825b827bc598e Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Thu, 13 Apr 2023 14:41:54 +0400
Subject: [PATCH] Force transformers to use config.torch_dtype by default
 (#307)

---
 src/petals/client/remote_model.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py
index dc987e4..d67d4bf 100644
--- a/src/petals/client/remote_model.py
+++ b/src/petals/client/remote_model.py
@@ -71,20 +71,33 @@ def force_non_empty_weights():
         nn.Module.register_parameter = possibly_patched_register_parameter
 
 
-class _LowCPUMemoryMixin:
+class _FromPretrainedDefaultsMixin:
     @classmethod
-    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",
         "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
     )
 
 
-class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
+class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
     )
 
 
-class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
+class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = (
@@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo
             self.lm_head.bias[...] = new_lm_head.bias
 
 
-class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
+class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
     _keys_to_ignore_on_load_missing = (
         BloomForSequenceClassification._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing
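
Usage note (not part of the patch): a minimal sketch of how the new defaults look from the caller's side. The classes and keyword arguments come from the diff above; the import path follows the patched file (src/petals/client/remote_model.py), and the checkpoint name is illustrative rather than taken from the patch.

import torch
from petals.client.remote_model import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"  # illustrative checkpoint name, not part of the patch

# With no torch_dtype argument, the mixin now substitutes torch_dtype="auto",
# so weights are loaded in config.torch_dtype (or the checkpoint's dtype)
# rather than transformers' torch.float32 default; low_cpu_mem_usage defaults to True.
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

# Explicit arguments still override the Petals defaults.
model_fp32 = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)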