From 9cb4c721e79ccb5e20fec449e62d092c271c95fc Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sun, 3 Sep 2023 01:55:50 +0300
Subject: [PATCH] Fix checking for nonexistent keys

---
 src/petals/server/from_pretrained.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index 85617e7..c52ad4b 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -70,11 +70,12 @@ def load_pretrained_block(
     assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
 
     for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+        if param_name != "self_attn.qkv_proj.weight":
+            assert param_name in state_dict, f"{param_name} not in state dict"
+            param = state_dict[param_name]
+            if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                param = param.to(torch_dtype)
+            set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
     logger.debug(f"Details: {report}")
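
Note on the dtype handling in the patched loop: the cast to `torch_dtype` is skipped for integer and boolean tensors so that things like masks or quantization metadata keep their original dtype. Below is a minimal standalone sketch of just that casting rule; the helper name and the example state dict are illustrative and not part of the Petals codebase.

```python
import torch

def cast_for_loading(param: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor:
    """Cast a checkpoint tensor to the serving dtype, but leave
    uint/int/bool tensors untouched, mirroring the check in the patch."""
    if str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        return param
    return param.to(torch_dtype)

# Hypothetical state dict: one float weight, one boolean mask.
state_dict = {
    "weight": torch.randn(4, 4, dtype=torch.float32),
    "mask": torch.ones(4, dtype=torch.bool),
}
loaded = {name: cast_for_loading(t, torch.bfloat16) for name, t in state_dict.items()}
print({name: t.dtype for name, t in loaded.items()})
# {'weight': torch.bfloat16, 'mask': torch.bool}
```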