Create dummy data when materializing qkv_proj

This commit is contained in:
Max Ryabinin 2023-09-03 19:20:07 +03:00
parent 9cb4c721e7
commit 4159e557bf

View File

@ -76,6 +76,10 @@ def load_pretrained_block(
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
param = param.to(torch_dtype) param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
else:
cur_block = getattr(block, param_name)
dummy_value = torch.empty_like(cur_block, device="cpu")
set_module_tensor_to_device(block, param_name, "cpu", dummy_value)
logger.info(f"Loaded {model_name} block {block_index}") logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}") logger.debug(f"Details: {report}")