Fix checking for nonexistent keys

pull/498/head
Max Ryabinin 9 months ago
parent 16fb547960
commit 9cb4c721e7

@ -70,6 +70,7 @@ def load_pretrained_block(
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
if param_name != "self_attn.qkv_proj.weight":
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):

Loading…
Cancel
Save