fix: handle unexpected keys in weights better

pull/437/head
Bryce 4 months ago committed by Bryce Drennan
parent 5b3b04b877
commit f84406f12c

@ -23,6 +23,8 @@ def imaginairy_click_context(log_level="INFO"):
yield
except errors_to_catch as e:
logger.error(e)
# import traceback
# traceback.print_exc()
def _imagine_cmd(

@ -1,6 +1,5 @@
import importlib
import logging
import numpy as np
import platform
import random
import re
@ -9,6 +8,7 @@ from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional
import numpy as np
import torch
from torch import Tensor, autocast
from torch.nn import functional
@ -337,6 +337,7 @@ def clear_gpu_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
@ -344,4 +345,4 @@ def seed_everything(seed: int | None = None) -> None:
random.seed(a=seed)
np.random.seed(seed=seed)
torch.manual_seed(seed=seed)
torch.cuda.manual_seed_all(seed=seed)
torch.cuda.manual_seed_all(seed=seed)

@ -74,7 +74,10 @@ class WeightMap:
def cast_weights(self, source_weights) -> dict[str, "Tensor"]:
converted_state_dict: dict[str, "Tensor"] = {}
for source_key in source_weights:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
try:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
except ValueError:
continue
# handle aliases
source_prefix = self.source_aliases.get(source_prefix, source_prefix)
try:

Loading…
Cancel
Save