Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)
Fix checking for nonexistent keys
This commit is contained in:
parent
16fb547960
commit
9cb4c721e7
@@ -70,11 +70,12 @@ def load_pretrained_block(
     assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
 
     for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+        if param_name != "self_attn.qkv_proj.weight":
+            assert param_name in state_dict, f"{param_name} not in state dict"
+            param = state_dict[param_name]
+            if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                param = param.to(torch_dtype)
+            set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
     logger.debug(f"Details: {report}")
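For context, a minimal self-contained sketch of the patched loading loop is shown below. ToyAttention, ToyBlock, and the hand-built state_dict are illustrative stand-ins, not the petals classes (the real loop lives in load_pretrained_block and receives a checkpoint's state dict); set_module_tensor_to_device is the accelerate utility used in the source.

import torch
from torch import nn
from accelerate.utils import set_module_tensor_to_device


class ToyAttention(nn.Module):
    """Stand-in attention layer with a fused q/k/v projection."""

    def __init__(self) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(4, 12, bias=False)


class ToyBlock(nn.Module):
    """Stand-in transformer block (illustrative, not the petals block)."""

    def __init__(self) -> None:
        super().__init__()
        self.self_attn = ToyAttention()
        self.mlp = nn.Linear(4, 4)


torch_dtype = torch.float16  # stand-in for the serving dtype
block = ToyBlock()

# A checkpoint that deliberately lacks the fused QKV key, mimicking the case
# this commit addresses: that weight is assembled elsewhere, not read from disk.
state_dict = {
    name: param.detach().clone()
    for name, param in block.named_parameters()
    if name != "self_attn.qkv_proj.weight"
}

for param_name, _ in block.named_parameters():
    # The fix: skip the key that is legitimately absent from the state dict,
    # instead of asserting on (and then indexing) a nonexistent entry.
    if param_name != "self_attn.qkv_proj.weight":
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        # Integer and boolean tensors (e.g. quantization buffers) keep their
        # dtype; floating-point weights are cast to the serving dtype.
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

print({name: param.dtype for name, param in block.named_parameters()})
# mlp.weight and mlp.bias are now torch.float16; self_attn.qkv_proj.weight is untouched.

Before this change, the unconditional assert (and the state_dict[param_name] lookup after it) fired for self_attn.qkv_proj.weight whenever a model kept its fused QKV weight out of the checkpoint's state dict, aborting the block load.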