@@ -71,20 +71,33 @@ def force_non_empty_weights():
         nn.Module.register_parameter = possibly_patched_register_parameter
 
 
-class _LowCPUMemoryMixin:
+class _FromPretrainedDefaultsMixin:
     @classmethod
-    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
+    )
 
 
-class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
+class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         )
 
 
-class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
+class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = (
@@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo
         self.lm_head.bias[...] = new_lm_head.bias
 
 
-class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
+class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
     _keys_to_ignore_on_load_missing = (
         BloomForSequenceClassification._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing
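
Usage sketch of the new defaults: with _FromPretrainedDefaultsMixin, every Distributed* class gets low_cpu_mem_usage=True and torch_dtype="auto" unless the caller passes explicit values. The checkpoint name and import path below are assumptions for illustration, not taken from this diff:

    import torch
    from petals import DistributedBloomForCausalLM  # import path may differ across Petals versions

    # Both defaults come from _FromPretrainedDefaultsMixin; "bigscience/bloom-petals" is a placeholder name.
    model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals")

    # Passing torch_dtype explicitly skips the `if torch_dtype is None` branch, so no "auto" default is applied.
    model_fp32 = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", torch_dtype=torch.float32)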