Use float32 as default client dtype

pull/530/head
Aleksandr Borzunov 8 months ago
parent 46aef2cf2e
commit 8d41e0ed3d

@ -17,7 +17,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.2.0"
__version__ = "2.3.0.dev0"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -6,7 +6,6 @@ import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union
import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils
@ -22,21 +21,14 @@ class FromPretrainedMixin:
model_name_or_path: Union[str, os.PathLike, None],
*args,
low_cpu_mem_usage: Optional[bool] = None,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
**kwargs,
):
model_name_or_path = get_compatible_model_repo(model_name_or_path)
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
if torch_dtype is None:
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
torch_dtype = "auto"
with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
return super().from_pretrained(
model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
)
return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
"low_cpu_mem_usage(`bool`, *optional*)",

Loading…
Cancel
Save