|
|
|
@ -12,7 +12,9 @@ from torch import Tensor, device as Device, dtype as DType, nn
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
|
|
import imaginairy.vendored.refiners.fluxion.layers as fl
|
|
|
|
|
from imaginairy import config
|
|
|
|
|
from imaginairy.schema import WeightedPrompt
|
|
|
|
|
from imaginairy.utils.downloads import get_cached_url_path
|
|
|
|
|
from imaginairy.utils.feather_tile import rebuild_image, tile_image
|
|
|
|
|
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
|
|
|
|
|
ScaledDotProductAttention,
|
|
|
|
@ -22,6 +24,7 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpol
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
|
|
|
|
|
CLIPTextEncoderL,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion import SDXLIPAdapter
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
|
|
|
|
|
TLatentDiffusionModel,
|
|
|
|
|
)
|
|
|
|
@ -55,6 +58,11 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
|
|
|
|
|
SDXLUNet,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.weight_management.conversion import cast_weights
|
|
|
|
|
from imaginairy.weight_management.translators import (
|
|
|
|
|
diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
|
|
|
|
|
diffusers_ip_adapter_sdxl_to_refiners_translator,
|
|
|
|
|
transformers_image_encoder_to_refiners_translator,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -302,6 +310,65 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _get_ip_adapter(self, model_type: str):
|
|
|
|
|
valid_model_types = ["normal", "plus", "plus-face"]
|
|
|
|
|
if model_type not in valid_model_types:
|
|
|
|
|
msg = f"IP Adapter model_type must be one of {valid_model_types}"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
|
|
|
|
ip_adapter_weights_path = get_cached_url_path(
|
|
|
|
|
config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
|
|
|
|
|
)
|
|
|
|
|
clip_image_weights_path = get_cached_url_path(
|
|
|
|
|
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.fp16.safetensors"
|
|
|
|
|
)
|
|
|
|
|
if "plus" in model_type:
|
|
|
|
|
ip_adapter_weight_translator = (
|
|
|
|
|
diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
ip_adapter_weight_translator = (
|
|
|
|
|
diffusers_ip_adapter_sdxl_to_refiners_translator()
|
|
|
|
|
)
|
|
|
|
|
clip_image_weight_translator = (
|
|
|
|
|
transformers_image_encoder_to_refiners_translator()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
ip_adapter = SDXLIPAdapter(
|
|
|
|
|
target=self.unet,
|
|
|
|
|
weights=ip_adapter_weight_translator.load_and_translate_weights(
|
|
|
|
|
ip_adapter_weights_path
|
|
|
|
|
),
|
|
|
|
|
fine_grained="plus" in model_type,
|
|
|
|
|
)
|
|
|
|
|
ip_adapter.clip_image_encoder.load_state_dict(
|
|
|
|
|
clip_image_weight_translator.load_and_translate_weights(
|
|
|
|
|
clip_image_weights_path
|
|
|
|
|
),
|
|
|
|
|
assign=True,
|
|
|
|
|
)
|
|
|
|
|
ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
|
|
|
|
|
ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
|
|
|
|
|
return ip_adapter
|
|
|
|
|
|
|
|
|
|
def set_image_prompt(
    self, images: list[Image.Image], scale: float, model_type: str = "normal"
):
    """Condition generation on one or more reference images via an IP adapter.

    Builds the adapter for ``model_type``, calls ``inject()`` on it, applies
    ``scale``, computes one CLIP image embedding per image, and sets the
    arithmetic mean of those embeddings as the adapter's image embedding.

    Args:
        images: Reference images; must contain at least one image.
        scale: Value passed to the adapter's ``set_scale``.
        model_type: IP adapter variant, forwarded to ``_get_ip_adapter``.

    Raises:
        ValueError: If ``images`` is empty, or if ``model_type`` is rejected
            by ``_get_ip_adapter``.
    """
    # Fail fast: without this check an empty list only surfaces as a
    # ZeroDivisionError in the averaging step, after the adapter weights
    # have been loaded and inject() has already been called.
    if not images:
        msg = "set_image_prompt requires at least one image"
        raise ValueError(msg)

    ip_adapter = self._get_ip_adapter(model_type)
    ip_adapter.inject()
    ip_adapter.set_scale(scale)

    device = self.unet.device
    image_embeddings = [
        ip_adapter.compute_clip_image_embedding(
            ip_adapter.preprocess_image(image).to(device=device)
        )
        for image in images
    ]

    # Average the per-image embeddings so every image contributes equally.
    clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
    ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
|
|
|
|
|
|
|
|
|
def prompts_to_embeddings(
|
|
|
|
|
self, prompts: List[WeightedPrompt]
|
|
|
|
|
) -> tuple[Tensor, Tensor]:
|
|
|
|
|