feature: IPAdapter SDXL

pull/477/head
Bryce 3 months ago
parent 20d0193654
commit a81afc8bf0

@@ -52,6 +52,8 @@ def edit_cmd(
prompt,
negative_prompt,
prompt_strength,
image_prompt,
image_prompt_strength,
outdir,
output_file_extension,
repeats,
@@ -108,6 +110,8 @@ def edit_cmd(
prompt_strength=prompt_strength,
init_image=image_paths,
init_image_strength=image_strength,
image_prompt=image_prompt,
image_prompt_strength=image_prompt_strength,
outdir=outdir,
output_file_extension=output_file_extension,
repeats=repeats,

@@ -256,7 +256,68 @@ class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1):
return conditioning
class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
class SDXLImagePromptMixin(nn.Module):
    """Mixin that adds IP-Adapter image-prompt conditioning to an SDXL model.

    Expects the host class to expose a ``self.unet`` attribute (an SDXL UNet);
    adapter weights are moved to that UNet's device and dtype.
    """

    def _get_ip_adapter(self, model_type: str):
        """Build an SDXL IP-Adapter of the given variant with weights loaded.

        Args:
            model_type: One of ``"normal"``, ``"plus"``, or ``"plus-face"``.

        Returns:
            An ``SDXLIPAdapter`` targeting ``self.unet``, moved to the UNet's
            device and dtype (not yet injected).

        Raises:
            ValueError: If ``model_type`` is not a recognized variant.
        """
        valid_model_types = ["normal", "plus", "plus-face"]
        if model_type not in valid_model_types:
            msg = f"IP Adapter model_type must be one of {valid_model_types}"
            raise ValueError(msg)

        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
        )
        # NOTE(review): reuses the SD-2.1 unCLIP checkpoint as the CLIP image
        # encoder for SDXL — confirm this is the intended weight source.
        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)

        # The "plus"/"plus-face" checkpoints use the fine-grained ("plus")
        # architecture and therefore need a different weight-name translation.
        if "plus" in model_type:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
            )
        else:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_sdxl_to_refiners_translator()
            )
        clip_image_weight_translator = (
            transformers_image_encoder_to_refiners_translator()
        )

        ip_adapter = SDXLIPAdapter(
            target=self.unet,
            weights=ip_adapter_weight_translator.load_and_translate_weights(
                ip_adapter_weights_path
            ),
            fine_grained="plus" in model_type,
        )
        ip_adapter.clip_image_encoder.load_state_dict(
            clip_image_weight_translator.load_and_translate_weights(
                clip_image_weights_path
            ),
            assign=True,
        )
        # Keep the adapter (and its image encoder) on the same device/dtype as
        # the UNet so the injected cross-attention layers can run together.
        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
        ip_adapter.clip_image_encoder.to(
            device=self.unet.device, dtype=self.unet.dtype
        )
        return ip_adapter

    def set_image_prompt(
        self, images: "list[Image.Image]", scale: float, model_type: str = "normal"
    ):
        """Condition subsequent generation on one or more prompt images.

        Args:
            images: Prompt images; their CLIP image embeddings are averaged
                into a single conditioning embedding.
            scale: Strength of the image-prompt conditioning.
            model_type: IP-Adapter variant ("normal", "plus", "plus-face").

        Raises:
            ValueError: If ``images`` is empty or ``model_type`` is invalid.
        """
        if not images:
            # Fail fast: without this guard the averaging below raises a
            # ZeroDivisionError only after the adapter weights were downloaded
            # and loaded.
            msg = "set_image_prompt requires at least one image"
            raise ValueError(msg)

        ip_adapter = self._get_ip_adapter(model_type)
        ip_adapter.inject()
        ip_adapter.set_scale(scale)

        image_embeddings = [
            ip_adapter.compute_clip_image_embedding(
                ip_adapter.preprocess_image(image).to(device=self.unet.device)
            )
            for image in images
        ]
        # Multiple prompt images are combined by averaging their embeddings.
        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
        ip_adapter.set_clip_image_embedding(clip_image_embedding)
class StableDiffusion_XL(
TileModeMixin, SDXLImagePromptMixin, RefinerStableDiffusion_XL
):
def __init__(
self,
unet: SDXLUNet | None = None,
@@ -374,63 +435,6 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
),
}
def _get_ip_adapter(self, model_type: str):
"""Build an SDXL IP-Adapter of the given variant with weights loaded.

model_type must be one of "normal", "plus", or "plus-face"; anything
else raises ValueError. Returns the adapter moved to the UNet's
device/dtype (not yet injected).
"""
valid_model_types = ["normal", "plus", "plus-face"]
if model_type not in valid_model_types:
msg = f"IP Adapter model_type must be one of {valid_model_types}"
raise ValueError(msg)
ip_adapter_weights_path = get_cached_url_path(
config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
)
# NOTE(review): SD-2.1 unCLIP weights used as the CLIP image encoder for
# SDXL — confirm this checkpoint is intended here.
clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
# "plus"/"plus-face" checkpoints use the fine-grained architecture and
# need a different diffusers->refiners weight-name translation.
if "plus" in model_type:
ip_adapter_weight_translator = (
diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
)
else:
ip_adapter_weight_translator = (
diffusers_ip_adapter_sdxl_to_refiners_translator()
)
clip_image_weight_translator = (
transformers_image_encoder_to_refiners_translator()
)
ip_adapter = SDXLIPAdapter(
target=self.unet,
weights=ip_adapter_weight_translator.load_and_translate_weights(
ip_adapter_weights_path
),
fine_grained="plus" in model_type,
)
ip_adapter.clip_image_encoder.load_state_dict(
clip_image_weight_translator.load_and_translate_weights(
clip_image_weights_path
),
assign=True,
)
# Keep adapter and its image encoder on the UNet's device/dtype so the
# injected layers can run alongside the UNet.
ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
return ip_adapter
def set_image_prompt(
self, images: list[Image.Image], scale: float, model_type: str = "normal"
):
"""Condition generation on prompt images via an injected IP-Adapter.

The CLIP embeddings of all images are averaged into one conditioning
embedding; scale sets the conditioning strength.
NOTE(review): an empty `images` list raises ZeroDivisionError at the
averaging step — consider validating up front.
"""
ip_adapter = self._get_ip_adapter(model_type)
ip_adapter.inject()
ip_adapter.set_scale(scale)
image_embeddings = []
for image in images:
image_embedding = ip_adapter.compute_clip_image_embedding(
ip_adapter.preprocess_image(image).to(device=self.unet.device)
)
image_embeddings.append(image_embedding)
# Multiple prompt images are combined by averaging their embeddings.
clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
def prompts_to_embeddings(
self, prompts: List[WeightedPrompt]
) -> tuple[Tensor, Tensor]:
@@ -659,7 +663,17 @@ class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
# NOTE(review): the next line is the pre-change code shown by the diff view;
# the if/else below is its replacement.
degraded_noise = self.unet(x)
# When an IP-Adapter is injected, its context carries a CLIP image embedding
# batched for classifier-free guidance; the SAG degraded pass runs the UNet
# on a single (non-CFG) batch, so temporarily keep only the first half of the
# embedding, then restore the full tensor afterwards.
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
# Save the full embedding so it can be restored after the degraded pass.
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
"clip_image_embedding"
].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)

Loading…
Cancel
Save