|
|
|
@ -256,7 +256,68 @@ class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusi
|
|
|
|
|
return conditioning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
|
|
|
|
|
class SDXLImagePromptMixin(nn.Module):
    """Mixin that adds IP-Adapter image prompting to an SDXL pipeline.

    Both methods read ``self.unet`` (its ``device`` and ``dtype``, and as the
    adapter target), so the host class must provide that attribute.
    """

    def _get_ip_adapter(self, model_type: str):
        """Build an ``SDXLIPAdapter`` targeting ``self.unet`` with weights loaded.

        Args:
            model_type: One of ``"normal"``, ``"plus"`` or ``"plus-face"``.

        Returns:
            The adapter, moved to the UNet's device and dtype. This method
            does not call ``inject()``; that is left to the caller.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        valid_model_types = ["normal", "plus", "plus-face"]
        if model_type not in valid_model_types:
            msg = f"IP Adapter model_type must be one of {valid_model_types}"
            raise ValueError(msg)

        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
        )
        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)

        # "plus" and "plus-face" both contain "plus": they share a weight
        # translator and both enable the adapter's fine_grained mode.
        if "plus" in model_type:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
            )
        else:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_sdxl_to_refiners_translator()
            )
        clip_image_weight_translator = (
            transformers_image_encoder_to_refiners_translator()
        )

        ip_adapter = SDXLIPAdapter(
            target=self.unet,
            weights=ip_adapter_weight_translator.load_and_translate_weights(
                ip_adapter_weights_path
            ),
            fine_grained="plus" in model_type,
        )
        ip_adapter.clip_image_encoder.load_state_dict(
            clip_image_weight_translator.load_and_translate_weights(
                clip_image_weights_path
            ),
            assign=True,
        )
        # The image encoder is moved explicitly in addition to the adapter.
        # NOTE(review): presumably ip_adapter.to() does not reach it - confirm.
        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
        return ip_adapter

    def set_image_prompt(
        self, images: list[Image.Image], scale: float, model_type: str = "normal"
    ):
        """Inject an IP adapter conditioned on the mean embedding of ``images``.

        Args:
            images: Reference images; must contain at least one entry.
            scale: Value forwarded to the adapter's ``set_scale``.
            model_type: IP-Adapter variant, validated by ``_get_ip_adapter``.

        Raises:
            ValueError: If ``images`` is empty or ``model_type`` is unsupported.
        """
        # Reject an empty list before anything is loaded or injected. Without
        # this guard the mean below fails with ZeroDivisionError only after
        # inject() has already been called on the adapter.
        if not images:
            msg = "set_image_prompt requires at least one image"
            raise ValueError(msg)

        ip_adapter = self._get_ip_adapter(model_type)
        ip_adapter.inject()
        ip_adapter.set_scale(scale)

        image_embeddings = [
            ip_adapter.compute_clip_image_embedding(
                ip_adapter.preprocess_image(image).to(device=self.unet.device)
            )
            for image in images
        ]

        # Element-wise mean of the per-image embeddings.
        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
        ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_XL(
|
|
|
|
|
TileModeMixin, SDXLImagePromptMixin, RefinerStableDiffusion_XL
|
|
|
|
|
):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
unet: SDXLUNet | None = None,
|
|
|
|
@ -374,63 +435,6 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _get_ip_adapter(self, model_type: str):
    """Create an ``SDXLIPAdapter`` for ``self.unet`` and load its weights.

    Args:
        model_type: One of ``"normal"``, ``"plus"`` or ``"plus-face"``.

    Returns:
        The adapter after it and its CLIP image encoder have been moved to
        the UNet's device and dtype. ``inject()`` is not called here.

    Raises:
        ValueError: If ``model_type`` is not a supported value.
    """
    valid_model_types = ["normal", "plus", "plus-face"]
    if model_type not in valid_model_types:
        msg = f"IP Adapter model_type must be one of {valid_model_types}"
        raise ValueError(msg)

    adapter_path = get_cached_url_path(
        config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
    )
    encoder_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)

    # "plus" and "plus-face" share a translator and the fine_grained flag.
    is_plus = "plus" in model_type
    make_adapter_translator = (
        diffusers_ip_adapter_plus_sdxl_to_refiners_translator
        if is_plus
        else diffusers_ip_adapter_sdxl_to_refiners_translator
    )
    adapter_translator = make_adapter_translator()
    encoder_translator = transformers_image_encoder_to_refiners_translator()

    unet = self.unet
    adapter = SDXLIPAdapter(
        target=unet,
        weights=adapter_translator.load_and_translate_weights(adapter_path),
        fine_grained=is_plus,
    )
    encoder_state = encoder_translator.load_and_translate_weights(encoder_path)
    adapter.clip_image_encoder.load_state_dict(encoder_state, assign=True)

    # Move the adapter first, then the image encoder, as two separate calls.
    adapter.to(device=unet.device, dtype=unet.dtype)
    adapter.clip_image_encoder.to(device=unet.device, dtype=unet.dtype)
    return adapter
|
|
|
|
|
|
|
|
|
|
def set_image_prompt(
    self, images: list[Image.Image], scale: float, model_type: str = "normal"
):
    """Inject an IP adapter conditioned on the mean embedding of ``images``.

    Args:
        images: Reference images; must contain at least one entry.
        scale: Value forwarded to the adapter's ``set_scale``.
        model_type: IP-Adapter variant, validated by ``_get_ip_adapter``.

    Raises:
        ValueError: If ``images`` is empty or ``model_type`` is unsupported.
    """
    # Reject an empty list before anything is loaded or injected. Without
    # this guard the mean below fails with ZeroDivisionError only after
    # inject() has already been called on the adapter.
    if not images:
        msg = "set_image_prompt requires at least one image"
        raise ValueError(msg)

    ip_adapter = self._get_ip_adapter(model_type)
    ip_adapter.inject()
    ip_adapter.set_scale(scale)

    image_embeddings = [
        ip_adapter.compute_clip_image_embedding(
            ip_adapter.preprocess_image(image).to(device=self.unet.device)
        )
        for image in images
    ]

    # Element-wise mean of the per-image embeddings.
    clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
    ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
|
|
|
|
|
|
|
|
|
def prompts_to_embeddings(
|
|
|
|
|
self, prompts: List[WeightedPrompt]
|
|
|
|
|
) -> tuple[Tensor, Tensor]:
|
|
|
|
@ -659,7 +663,17 @@ class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
|
|
|
|
|
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
|
|
|
|
|
dim=1,
|
|
|
|
|
)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
if "ip_adapter" in self.unet.provider.contexts:
|
|
|
|
|
# this implementation is a bit hacky, it should be refactored in the future
|
|
|
|
|
ip_adapter_context = self.unet.use_context("ip_adapter")
|
|
|
|
|
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
|
|
|
|
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
|
|
|
|
|
"clip_image_embedding"
|
|
|
|
|
].chunk(2)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
|
|
|
|
else:
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
|
|
|
|
|
return sag.scale * (noise - degraded_noise)
|
|
|
|
|
|
|
|
|
|