Force transformers to use config.torch_dtype by default (#307)

pull/306/head^2
Alexander Borzunov 1 year ago committed by GitHub
parent 98be9ffe4c
commit c0e0e1319d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,20 +71,33 @@ def force_non_empty_weights():
nn.Module.register_parameter = possibly_patched_register_parameter
class _LowCPUMemoryMixin:
class _FromPretrainedDefaultsMixin:
@classmethod
def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
def from_pretrained(
cls,
*args,
low_cpu_mem_usage: Optional[bool] = None,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
**kwargs,
):
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
if torch_dtype is None:
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
torch_dtype = "auto"
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
"low_cpu_mem_usage(`bool`, *optional*)",
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
).replace(
"torch_dtype (`str` or `torch.dtype`, *optional*)",
'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
)
class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
)
class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = (
@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo
self.lm_head.bias[...] = new_lm_head.bias
class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
_keys_to_ignore_on_load_missing = (
BloomForSequenceClassification._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing

Loading…
Cancel
Save