|
|
|
@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from refiners.fluxion import load_from_safetensors
|
|
|
|
|
from safetensors import safe_open
|
|
|
|
|
from torch import device as Device
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@ -29,7 +29,8 @@ class WeightTranslationMap:
|
|
|
|
|
source_weights = torch.load(source_path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
elif extension in ["safetensors"]:
|
|
|
|
|
source_weights = load_from_safetensors(source_path, device=device)
|
|
|
|
|
with safe_open(source_path, framework="pt", device=device) as f: # type: ignore
|
|
|
|
|
source_weights = {k: f.get_tensor(k) for k in f.keys()} # noqa
|
|
|
|
|
else:
|
|
|
|
|
msg = f"Unsupported extension {extension}"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
@ -79,10 +80,30 @@ class WeightTranslationMap:
|
|
|
|
|
return cls(**d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_nan_path(path: str, device):
|
|
|
|
|
from safetensors import safe_open
|
|
|
|
|
|
|
|
|
|
with safe_open(path, framework="pt", device=device) as f: # type: ignore
|
|
|
|
|
for k in f.keys(): # noqa
|
|
|
|
|
if torch.any(torch.isnan(f.get_tensor(k))):
|
|
|
|
|
print(f"Found nan values in {k} of {path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate_weights(
|
|
|
|
|
source_weights: TensorDict, weight_map: WeightTranslationMap
|
|
|
|
|
) -> TensorDict:
|
|
|
|
|
new_state_dict: TensorDict = {}
|
|
|
|
|
# check source weights for nan
|
|
|
|
|
for k, v in source_weights.items():
|
|
|
|
|
nan_count = torch.sum(torch.isnan(v)).item()
|
|
|
|
|
if nan_count:
|
|
|
|
|
msg = (
|
|
|
|
|
f"Found {nan_count} nan values in {k} of source state dict."
|
|
|
|
|
" This could indicate the source weights are corrupted and "
|
|
|
|
|
"need to be re-downloaded. "
|
|
|
|
|
)
|
|
|
|
|
logger.warning(msg)
|
|
|
|
|
|
|
|
|
|
# print(f"Translating {len(source_weights)} weights")
|
|
|
|
|
# print(f"Using {len(weight_map.name_map)} name mappings")
|
|
|
|
|
# print(source_weights.keys())
|
|
|
|
@ -142,7 +163,7 @@ def translate_weights(
|
|
|
|
|
|
|
|
|
|
if source_weights:
|
|
|
|
|
msg = f"Unmapped keys: {list(source_weights.keys())}"
|
|
|
|
|
print(msg)
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
for k in source_weights:
|
|
|
|
|
if isinstance(source_weights[k], torch.Tensor):
|
|
|
|
|
print(f" {k}: {source_weights[k].shape}")
|
|
|
|
@ -154,6 +175,15 @@ def translate_weights(
|
|
|
|
|
if key in new_state_dict:
|
|
|
|
|
new_state_dict[key] = new_state_dict[key].reshape(new_shape)
|
|
|
|
|
|
|
|
|
|
# check for nan values
|
|
|
|
|
for k in list(new_state_dict.keys()):
|
|
|
|
|
v = new_state_dict[k]
|
|
|
|
|
nan_count = torch.sum(torch.isnan(v)).item()
|
|
|
|
|
if nan_count:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Found {nan_count} nan values in {k} of converted state dict."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|