imaginAIry/imaginairy/weight_management/translators.py

135 lines
4.8 KiB
Python
Raw Normal View History

import os
from functools import lru_cache
from torch import device as Device
from imaginairy.weight_management.translation import TensorDict, WeightTranslationMap
# Absolute path of the directory that contains this module.
_current_folder = os.path.split(os.path.abspath(__file__))[0]
# Directory in which ``load_weight_map`` looks for ``<name>.weightmap.json`` files.
weight_map_folder = os.path.join(_current_folder, "weight_maps")
@lru_cache
def load_weight_map(map_name: str) -> WeightTranslationMap:
    """Load the weight translation map called *map_name*.

    The file ``<map_name>.weightmap.json`` is read from ``weight_map_folder``
    and parsed with ``WeightTranslationMap.load``. ``lru_cache`` memoises the
    result per ``map_name``, so repeated calls with the same name return the
    same object.
    """
    json_filename = f"{map_name}.weightmap.json"
    return WeightTranslationMap.load(os.path.join(weight_map_folder, json_filename))
def transformers_text_encoder_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Transformers-ClipTextEncoder`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Transformers-ClipTextEncoder"
    return load_weight_map(map_name)
def transformers_image_encoder_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Transformers-ClipImageEncoder-SD21`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Transformers-ClipImageEncoder-SD21"
    return load_weight_map(map_name)
def diffusers_autoencoder_kl_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-AutoencoderKL-SD`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-AutoencoderKL-SD"
    return load_weight_map(map_name)
def diffusers_unet_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-UNet-SD15`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-UNet-SD15"
    return load_weight_map(map_name)
def diffusers_unet_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-UNet-SDXL`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-UNet-SDXL"
    return load_weight_map(map_name)
def informative_drawings_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``InformativeDrawings`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "InformativeDrawings"
    return load_weight_map(map_name)
def diffusers_controlnet_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-Controlnet-SD15`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-Controlnet-SD15"
    return load_weight_map(map_name)
def diffusers_ip_adapter_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-IPAdapter-SD15`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-IPAdapter-SD15"
    return load_weight_map(map_name)
def diffusers_ip_adapter_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-IPAdapter-SDXL`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-IPAdapter-SDXL"
    return load_weight_map(map_name)
def diffusers_ip_adapter_plus_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-IPAdapterPlus-SD15`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-IPAdapterPlus-SD15"
    return load_weight_map(map_name)
def diffusers_ip_adapter_plus_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-IPAdapterPlus-SDXL`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-IPAdapterPlus-SDXL"
    return load_weight_map(map_name)
def diffusers_t2i_adapter_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-T2IAdapter-SD15`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-T2IAdapter-SD15"
    return load_weight_map(map_name)
def diffusers_t2i_adapter_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Return the ``Diffusers-T2IAdapter-SDXL`` weight translation map.

    Loaded through ``load_weight_map``, which caches the map per name.
    """
    map_name = "Diffusers-T2IAdapter-SDXL"
    return load_weight_map(map_name)
class DoubleTextEncoderTranslator:
    """Merge an "L" and a "G" CLIP text-encoder checkpoint into one state dict.

    Both checkpoints are translated with the same
    ``Transformers-ClipTextEncoder`` weight map. Their keys are then
    re-prefixed so the two encoders live side by side under a single
    ``Parallel`` model layout (``Parallel.CLIPTextEncoderL`` and
    ``Parallel.TextEncoderWithPooling``).
    """

    # Key prefix of the L encoder's last transformer layer after translation.
    # Keys under it are discarded.
    _TEXT_ENCODER_L_LAST_LAYER = "TransformerLayer_12"
    # Key prefix of the G encoder's last transformer layer after translation.
    # It is re-homed under ``Parallel.Chain`` as an unnumbered ``TransformerLayer``.
    _TEXT_ENCODER_G_LAST_LAYER = "TransformerLayer_32"

    def __init__(self):
        # A single translation map serves both encoders (see translate_weights).
        self.translator = transformers_text_encoder_to_refiners_translator()

    def load_untranslated_weights(
        self,
        text_encoder_l_weights_path: str,
        text_encoder_g_weights_path: str,
        device: Device | str = "cpu",
    ) -> tuple[TensorDict, TensorDict]:
        """Load both checkpoints without renaming any keys.

        Args:
            text_encoder_l_weights_path: Path of the L text-encoder weights.
            text_encoder_g_weights_path: Path of the G text-encoder weights.
            device: Forwarded as ``device=`` to the translator's loader.

        Returns:
            ``(l_weights, g_weights)`` exactly as the translator loaded them.
        """
        text_encoder_l_weights = self.translator.load_untranslated_weights(
            text_encoder_l_weights_path, device=device
        )
        text_encoder_g_weights = self.translator.load_untranslated_weights(
            text_encoder_g_weights_path, device=device
        )
        return text_encoder_l_weights, text_encoder_g_weights

    def load_and_translate_weights(
        self,
        text_encoder_l_weights_path: str,
        text_encoder_g_weights_path: str,
        device: Device | str = "cpu",
    ) -> TensorDict:
        """Load both checkpoints and merge them via ``translate_weights``.

        Args:
            text_encoder_l_weights_path: Path of the L text-encoder weights.
            text_encoder_g_weights_path: Path of the G text-encoder weights.
            device: Forwarded to ``load_untranslated_weights``.

        Returns:
            The combined state dict produced by ``translate_weights``.
        """
        text_encoder_l_weights, text_encoder_g_weights = self.load_untranslated_weights(
            text_encoder_l_weights_path, text_encoder_g_weights_path, device=device
        )
        return self.translate_weights(text_encoder_l_weights, text_encoder_g_weights)

    def translate_weights(
        self, text_encoder_l_weights: TensorDict, text_encoder_g_weights: TensorDict
    ) -> TensorDict:
        """Translate both weight dicts and merge them under combined-model keys.

        Key handling for the L encoder (after translation):

        * keys under ``_TEXT_ENCODER_L_LAST_LAYER`` or the top-level
          ``LayerNorm`` are dropped;
        * everything else is prefixed with ``Parallel.CLIPTextEncoderL.``.

        Key handling for the G encoder (after translation):

        * ``Linear.weight`` goes to ``Parallel.Chain.Linear.weight`` under
          ``Parallel.TextEncoderWithPooling``;
        * the last transformer layer (``_TEXT_ENCODER_G_LAST_LAYER``) and the
          top-level ``LayerNorm`` go under the same ``Parallel.Chain`` branch;
        * everything else goes under ``TextEncoderWithPooling.CLIPTextEncoderG``.

        Entries are popped from the translated dicts as they are moved.

        Raises:
            KeyError: If the translated G weights have no ``Linear.weight``.
        """
        new_sd: TensorDict = {}
        text_encoder_l_weights = self.translator.translate_weights(
            text_encoder_l_weights
        )
        text_encoder_g_weights = self.translator.translate_weights(
            text_encoder_g_weights
        )

        # L encoder: presumably the combined model does not use the L encoder's
        # final layer / final norm (TODO confirm), so those weights are dropped.
        l_dropped_prefixes = (self._TEXT_ENCODER_L_LAST_LAYER, "LayerNorm")
        for k in list(text_encoder_l_weights.keys()):
            if k.startswith(l_dropped_prefixes):
                text_encoder_l_weights.pop(k)
            else:
                new_sd[f"Parallel.CLIPTextEncoderL.{k}"] = text_encoder_l_weights.pop(k)

        # G encoder: the last layer, final norm and Linear projection live in the
        # ``Parallel.Chain`` branch (presumably the pooled-output path, given the
        # ``TextEncoderWithPooling`` name); the other layers stay in the main encoder.
        chain_prefix = "Parallel.TextEncoderWithPooling.Parallel.Chain"
        encoder_g_prefix = "Parallel.TextEncoderWithPooling.CLIPTextEncoderG"
        g_last_layer = self._TEXT_ENCODER_G_LAST_LAYER

        new_sd[f"{chain_prefix}.Linear.weight"] = text_encoder_g_weights.pop(
            "Linear.weight"
        )
        for k in list(text_encoder_g_weights.keys()):
            if k.startswith(g_last_layer):
                # "TransformerLayer_32.<rest>" -> "TransformerLayer.<rest>"; the layer
                # index is removed by prefix rather than by a hard-coded slice length.
                new_key = (
                    f"{chain_prefix}.CLIPTextEncoderG.TransformerLayer"
                    f"{k.removeprefix(g_last_layer)}"
                )
            elif k.startswith("LayerNorm"):
                new_key = f"{chain_prefix}.CLIPTextEncoderG.{k}"
            else:
                new_key = f"{encoder_g_prefix}.{k}"
            new_sd[new_key] = text_encoder_g_weights.pop(k)
        return new_sd