Force transformers to use config.torch_dtype by default (#307)

Alexander Borzunov authored 1 year ago · committed by GitHub
parent 98be9ffe4c
commit c0e0e1319d

@@ -71,20 +71,33 @@ def force_non_empty_weights():
         nn.Module.register_parameter = possibly_patched_register_parameter
 
 
-class _LowCPUMemoryMixin:
+class _FromPretrainedDefaultsMixin:
     @classmethod
-    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",
         "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
     )
 
 
-class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
+class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
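
For reference, the comment in the hunk above describes the following transformers behavior. A minimal sketch, assuming transformers>=4.26; "bigscience/bloom-560m" is used purely as an example of a checkpoint whose config sets torch_dtype:

from transformers import AutoModelForCausalLM

# Default behavior (torch_dtype=None): transformers>=4.26 loads the weights as torch.float32.
model_default = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# torch_dtype="auto": use config.torch_dtype if it exists, otherwise fall back to the dtype of the weights.
model_auto = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype="auto")

print(model_default.dtype, model_auto.dtype)  # e.g. torch.float32 vs torch.float16
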
@@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         )
 
 
-class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
+class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = (
@@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo
         self.lm_head.bias[...] = new_lm_head.bias
 
 
-class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
+class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
     _keys_to_ignore_on_load_missing = (
         BloomForSequenceClassification._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing
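
With this change, Petals clients that do not pass torch_dtype inherit config.torch_dtype instead of float32. A rough usage sketch; the import path and checkpoint name are assumptions and may differ across Petals versions:

import torch
from petals import DistributedBloomForCausalLM  # import path is an assumption

# No torch_dtype given: the mixin now forwards torch_dtype="auto" to transformers,
# so the model loads in config.torch_dtype (e.g. bfloat16) rather than float32.
model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals")
print(model.dtype)

# Passing torch_dtype explicitly still overrides the new default.
model_fp32 = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", torch_dtype=torch.float32)
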
