Fix checking for nonexistent keys

This commit is contained in:
Max Ryabinin 2023-09-03 01:55:50 +03:00
parent 16fb547960
commit 9cb4c721e7

View File

@ -70,11 +70,12 @@ def load_pretrained_block(
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters(): for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict" if param_name != "self_attn.qkv_proj.weight":
param = state_dict[param_name] assert param_name in state_dict, f"{param_name} not in state dict"
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = state_dict[param_name]
param = param.to(torch_dtype) if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
logger.info(f"Loaded {model_name} block {block_index}") logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}") logger.debug(f"Details: {report}")