|
|
|
@@ -6,7 +6,6 @@ import tempfile
|
|
|
|
|
from contextvars import ContextVar
|
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
from transformers import BloomPreTrainedModel, modeling_utils
|
|
|
|
|
|
|
|
|
@@ -22,21 +21,14 @@ class FromPretrainedMixin:
|
|
|
|
|
model_name_or_path: Union[str, os.PathLike, None],
|
|
|
|
|
*args,
|
|
|
|
|
low_cpu_mem_usage: Optional[bool] = None,
|
|
|
|
|
torch_dtype: Optional[Union[str, torch.dtype]] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
model_name_or_path = get_compatible_model_repo(model_name_or_path)
|
|
|
|
|
if low_cpu_mem_usage is None:
|
|
|
|
|
low_cpu_mem_usage = True
|
|
|
|
|
if torch_dtype is None:
|
|
|
|
|
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
|
|
|
|
|
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
|
|
|
|
|
torch_dtype = "auto"
|
|
|
|
|
|
|
|
|
|
with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
|
|
|
|
|
return super().from_pretrained(
|
|
|
|
|
model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
|
|
|
|
|
|
|
|
|
|
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
|
|
|
|
"low_cpu_mem_usage(`bool`, *optional*)",
|
|
|
|
|