2023-12-28 05:52:37 +00:00
|
|
|
import os
|
|
|
|
from functools import lru_cache
|
|
|
|
|
|
|
|
from torch import device as Device
|
|
|
|
|
|
|
|
from imaginairy.weight_management.translation import TensorDict, WeightTranslationMap
|
|
|
|
|
|
|
|
# Absolute path of the directory containing this module.
_current_folder = os.path.dirname(os.path.abspath(__file__))

# Directory holding the bundled *.weightmap.json translation files
# loaded by load_weight_map() below.
weight_map_folder = os.path.join(_current_folder, "weight_maps")
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache
def load_weight_map(map_name: str) -> WeightTranslationMap:
    """Load the named weight-translation map from the bundled JSON files.

    Results are cached, so each map file is read at most once per process.
    """
    filename = f"{map_name}.weightmap.json"
    return WeightTranslationMap.load(os.path.join(weight_map_folder, filename))
|
|
|
|
|
|
|
|
|
|
|
|
def transformers_text_encoder_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: transformers CLIP text encoder -> refiners layout."""
    map_name = "Transformers-ClipTextEncoder"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def transformers_image_encoder_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: transformers CLIP image encoder (SD 2.1) -> refiners layout."""
    map_name = "Transformers-ClipImageEncoder-SD21"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_autoencoder_kl_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers AutoencoderKL (SD) -> refiners layout."""
    map_name = "Diffusers-AutoencoderKL-SD"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_unet_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers UNet (SD 1.5) -> refiners layout."""
    map_name = "Diffusers-UNet-SD15"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_unet_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers UNet (SDXL) -> refiners layout."""
    map_name = "Diffusers-UNet-SDXL"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def informative_drawings_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: Informative Drawings model -> refiners layout."""
    map_name = "InformativeDrawings"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_controlnet_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers ControlNet (SD 1.5) -> refiners layout."""
    map_name = "Diffusers-Controlnet-SD15"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_ip_adapter_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers IP-Adapter (SD 1.5) -> refiners layout."""
    map_name = "Diffusers-IPAdapter-SD15"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_ip_adapter_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers IP-Adapter (SDXL) -> refiners layout."""
    map_name = "Diffusers-IPAdapter-SDXL"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_ip_adapter_plus_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers IP-Adapter Plus (SD 1.5) -> refiners layout."""
    map_name = "Diffusers-IPAdapterPlus-SD15"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_ip_adapter_plus_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers IP-Adapter Plus (SDXL) -> refiners layout."""
    map_name = "Diffusers-IPAdapterPlus-SDXL"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_t2i_adapter_sd15_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers T2I-Adapter (SD 1.5) -> refiners layout."""
    map_name = "Diffusers-T2IAdapter-SD15"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
def diffusers_t2i_adapter_sdxl_to_refiners_translator() -> WeightTranslationMap:
    """Weight map: diffusers T2I-Adapter (SDXL) -> refiners layout."""
    map_name = "Diffusers-T2IAdapter-SDXL"
    return load_weight_map(map_name)
|
2023-12-28 05:52:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DoubleTextEncoderTranslator:
    """Translates weights for a dual text encoder (CLIP-L + CLIP-G) into a
    single refiners-style state dict.

    Both encoders' raw weights are first run through the same
    transformers->refiners CLIP text-encoder map, then re-keyed so the two
    encoders can live side by side inside a ``Parallel`` structure.
    NOTE(review): the target key layout presumably mirrors refiners' SDXL
    double text encoder module tree — confirm against that model definition.
    """

    def __init__(self):
        # Shared base translation map; used for both the L and G encoders.
        self.translator = transformers_text_encoder_to_refiners_translator()

    def load_untranslated_weights(
        self,
        text_encoder_l_weights_path: str,
        text_encoder_g_weights_path: str,
        device: Device | str = "cpu",
    ) -> tuple[TensorDict, TensorDict]:
        """Load the raw (not yet re-keyed) state dicts for both encoders.

        Returns a ``(text_encoder_l_weights, text_encoder_g_weights)`` tuple,
        each loaded onto ``device`` via the shared translator.
        """
        text_encoder_l_weights = self.translator.load_untranslated_weights(
            text_encoder_l_weights_path, device=device
        )
        text_encoder_g_weights = self.translator.load_untranslated_weights(
            text_encoder_g_weights_path, device=device
        )
        return text_encoder_l_weights, text_encoder_g_weights

    def load_and_translate_weights(
        self,
        text_encoder_l_weights_path: str,
        text_encoder_g_weights_path: str,
        device: Device | str = "cpu",
    ) -> TensorDict:
        """Convenience wrapper: load both encoders' weights and translate them
        into one combined state dict (see ``translate_weights``)."""
        text_encoder_l_weights, text_encoder_g_weights = self.load_untranslated_weights(
            text_encoder_l_weights_path, text_encoder_g_weights_path, device=device
        )
        return self.translate_weights(text_encoder_l_weights, text_encoder_g_weights)

    def translate_weights(
        self, text_encoder_l_weights: TensorDict, text_encoder_g_weights: TensorDict
    ) -> TensorDict:
        """Merge both encoders' translated weights into one re-keyed dict.

        WARNING: both input dicts are consumed — every key is ``pop``-ed, so
        callers should treat them as emptied after this returns.
        """
        new_sd: TensorDict = {}
        # Apply the shared transformers->refiners key translation first.
        text_encoder_l_weights = self.translator.translate_weights(
            text_encoder_l_weights
        )
        text_encoder_g_weights = self.translator.translate_weights(
            text_encoder_g_weights
        )
        # CLIP-L: drop "TransformerLayer_12*" and "LayerNorm*" keys entirely
        # (presumably the final layer / final norm are unused for the L
        # encoder's penultimate-layer output — TODO confirm); everything else
        # is nested under the Parallel.CLIPTextEncoderL branch.
        for k in list(text_encoder_l_weights.keys()):
            if k.startswith("TransformerLayer_12"):
                text_encoder_l_weights.pop(k)
            elif k.startswith("LayerNorm"):
                text_encoder_l_weights.pop(k)
            else:
                new_key = f"Parallel.CLIPTextEncoderL.{k}"
                new_sd[new_key] = text_encoder_l_weights.pop(k)
        # CLIP-G: its "Linear.weight" becomes the pooled-text projection.
        # Raises KeyError if the G weights lack that key.
        new_sd["Parallel.TextEncoderWithPooling.Parallel.Chain.Linear.weight"] = (
            text_encoder_g_weights.pop("Linear.weight")
        )
        for k in list(text_encoder_g_weights.keys()):
            if k.startswith("TransformerLayer_32"):
                # k[19:] strips the 19-char "TransformerLayer_32" prefix, so
                # the new key reads "...TransformerLayer<rest>" without the
                # layer index — this layer is routed into the pooling chain.
                new_key = f"Parallel.TextEncoderWithPooling.Parallel.Chain.CLIPTextEncoderG.TransformerLayer{k[19:]}"
            elif k.startswith("LayerNorm"):
                new_key = f"Parallel.TextEncoderWithPooling.Parallel.Chain.CLIPTextEncoderG.{k}"
            else:
                new_key = f"Parallel.TextEncoderWithPooling.CLIPTextEncoderG.{k}"

            new_sd[new_key] = text_encoder_g_weights.pop(k)

        return new_sd
|