imaginAIry/scripts/controlnet_convert.py
2024-01-14 16:50:17 -08:00

131 lines
4.6 KiB
Python

import os
import torch
from safetensors.torch import load_file, save_file
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.paths import PKG_ROOT
sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"
def main():
"""Script to convert the controlnet weights into diffs that are ready to be applied to any s1.5 weights."""
control_types = [
"canny",
"depth",
"hed",
"mlsd",
"normal",
"openpose",
"scribble",
"seg",
]
url_template = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_{control_type}.pth"
urls = {
control_type: url_template.format(control_type=control_type)
for control_type in control_types
}
dest = f"{PKG_ROOT}/../other/weights/controlnet"
for control_type, url in urls.items():
print(f"Downloading {control_type} weights from {url}")
out_filepath = extract_controlnet_essence(
control_type=control_type,
controlnet_url=url,
dest_folder=dest,
)
sd15_path = get_cached_url_path(sd15_url)
sd15_state_dict = torch.load(sd15_path, map_location="cpu")
sd15_state_dict = sd15_state_dict.get("state_dict", sd15_state_dict)
reconstituted_controlnet_statedict = apply_controlnet(
base_state_dict=sd15_state_dict,
controlnet_state_dict=load_file(out_filepath),
)
controlnet_path = get_cached_url_path(url)
import time
time.sleep(1)
controlnet_statedict = torch.load(controlnet_path, map_location="cpu")
print("\n\nComparing reconstructed controlnet with original")
for k in controlnet_statedict:
if k not in reconstituted_controlnet_statedict:
print(f"Key {k} not in reconstituted")
elif (
controlnet_statedict[k].shape
!= reconstituted_controlnet_statedict[k].shape
):
print(f"Key {k} has different shape")
print(controlnet_statedict[k].shape)
print(reconstituted_controlnet_statedict[k].shape)
else:
diff = controlnet_statedict[k] - reconstituted_controlnet_statedict[k]
diff_sum = torch.abs(diff).sum()
if diff_sum > 3.467949682089966e-05:
print(f"Key {k} has different values {diff_sum}")
def extract_controlnet_essence(control_type, controlnet_url, dest_folder):
print(f"Extracting essence of {control_type} weights from {controlnet_url}")
outpath = f"{dest_folder}/controlnet15_diff_{control_type}.safetensors"
if os.path.exists(outpath):
print(f"File {outpath} already exists, skipping")
return outpath
os.makedirs(dest_folder, exist_ok=True)
sd15_path = get_cached_url_path(sd15_url)
controlnet_path = get_cached_url_path(controlnet_url)
print(f"sd15_path: {sd15_path}")
print(f"controlnet_path: {controlnet_path}")
sd15_state_dict = torch.load(sd15_path, map_location="cpu")
sd15_state_dict = sd15_state_dict.get("state_dict", sd15_state_dict)
controlnet_state_dict = torch.load(controlnet_path, map_location="cpu")
controlnet_state_dict = controlnet_state_dict.get(
"state_dict", controlnet_state_dict
)
final_state_dict = {}
skip_prefixes = ("first_stage_model", "cond_stage_model")
for key in controlnet_state_dict:
if key.startswith(skip_prefixes):
continue
if key.startswith("control_"):
sd15_key_name = "model.diffusion_" + key[len("control_") :]
else:
sd15_key_name = key
if sd15_key_name in sd15_state_dict:
diff_value = controlnet_state_dict[key] - sd15_state_dict[sd15_key_name]
final_state_dict[key] = diff_value
else:
final_state_dict[key] = controlnet_state_dict[key]
save_file(final_state_dict, outpath)
return outpath
def apply_controlnet(base_state_dict, controlnet_state_dict):
for key in controlnet_state_dict:
if key.startswith("control_"):
sd15_key_name = "model.diffusion_" + key[len("control_") :]
else:
sd15_key_name = key
if sd15_key_name in base_state_dict:
b = base_state_dict[sd15_key_name]
c_diff = controlnet_state_dict[key]
new_c = b + c_diff
base_state_dict[key] = new_c
else:
base_state_dict[key] = controlnet_state_dict[key]
return base_state_dict
if __name__ == "__main__":
main()