fix: handle unexpected keys in weights better

pull/437/head
Bryce 5 months ago committed by Bryce Drennan
parent 5b3b04b877
commit f84406f12c

@@ -23,6 +23,8 @@ def imaginairy_click_context(log_level="INFO"):
yield yield
except errors_to_catch as e: except errors_to_catch as e:
logger.error(e) logger.error(e)
# import traceback
# traceback.print_exc()
def _imagine_cmd( def _imagine_cmd(

@@ -1,6 +1,5 @@
import importlib import importlib
import logging import logging
import numpy as np
import platform import platform
import random import random
import re import re
@@ -9,6 +8,7 @@ from contextlib import contextmanager, nullcontext
from functools import lru_cache from functools import lru_cache
from typing import Any, List, Optional from typing import Any, List, Optional
import numpy as np
import torch import torch
from torch import Tensor, autocast from torch import Tensor, autocast
from torch.nn import functional from torch.nn import functional
@@ -337,6 +337,7 @@ def clear_gpu_cache():
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def seed_everything(seed: int | None = None) -> None: def seed_everything(seed: int | None = None) -> None:
if seed is None: if seed is None:
seed = random.randint(0, 2**32 - 1) seed = random.randint(0, 2**32 - 1)
@@ -344,4 +345,4 @@ def seed_everything(seed: int | None = None) -> None:
random.seed(a=seed) random.seed(a=seed)
np.random.seed(seed=seed) np.random.seed(seed=seed)
torch.manual_seed(seed=seed) torch.manual_seed(seed=seed)
torch.cuda.manual_seed_all(seed=seed) torch.cuda.manual_seed_all(seed=seed)

@@ -74,7 +74,10 @@ class WeightMap:
def cast_weights(self, source_weights) -> dict[str, "Tensor"]: def cast_weights(self, source_weights) -> dict[str, "Tensor"]:
converted_state_dict: dict[str, "Tensor"] = {} converted_state_dict: dict[str, "Tensor"] = {}
for source_key in source_weights: for source_key in source_weights:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1) try:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
except ValueError:
continue
# handle aliases # handle aliases
source_prefix = self.source_aliases.get(source_prefix, source_prefix) source_prefix = self.source_aliases.get(source_prefix, source_prefix)
try: try:

Loading…
Cancel
Save