@ -70,6 +70,7 @@ def load_pretrained_block(
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
if param_name != "self_attn.qkv_proj.weight":
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):