feature: IPAdapter sd1.5

pull/477/head
Bryce 3 months ago
parent 2e5e20c0d5
commit 20d0193654

@@ -276,6 +276,12 @@ def generate_single_image(
            )
        controlnets.append((controlnet, control_image_t))
 
+    if prompt.image_prompt:
+        sd.set_image_prompt(
+            prompt.image_prompt,
+            scale=prompt.image_prompt_strength,
+            model_type="plus",
+        )
    for controlnet, control_image_t in controlnets:
        controlnet.set_controlnet_condition(
            control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
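
With this hunk, `prompt.image_prompt` is forwarded to `sd.set_image_prompt` before sampling begins. A minimal end-to-end sketch of the feature from the caller's side (the import paths, the list-valued field shape, and the filenames are assumptions for illustration, not part of this diff):

    # hypothetical usage of the new image_prompt fields
    from imaginairy import ImaginePrompt, LazyLoadingImage, imagine

    prompt = ImaginePrompt(
        prompt="a portrait in the style of the reference photo",
        image_prompt=[LazyLoadingImage(filepath="reference.jpg")],  # hypothetical file
        image_prompt_strength=0.35,
    )
    result = next(imagine([prompt]))  # imagine() yields one result per prompt
    result.img.save("out.jpg")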
@@ -293,6 +299,15 @@ def generate_single_image(
    sd.set_inference_steps(prompt.steps, first_step=first_step)
 
    if hasattr(sd, "mask_latents") and mask_image is not None:
+        # import numpy as np
+        # init_size = init_image.size
+        # noise_image = Image.fromarray(np.random.randint(0, 255, (init_size[1], init_size[0], 3), dtype=np.uint8))
+        # masked_image = Image.composite(init_image, noise_image, mask_image)
+        masked_image = Image.composite(
+            init_image, mask_image.convert("RGB"), mask_image
+        )
+        result_images["masked_image"] = masked_image
+
        sd.set_inpainting_conditions(
            target_image=init_image,
            mask=ImageOps.invert(mask_image),
@@ -304,12 +319,6 @@ def generate_single_image(
        sd.mask_latents = sd.mask_latents.to(
            dtype=sd.unet.dtype, device=sd.unet.device
        )
-    if prompt.image_prompt:
-        sd.set_image_prompt(
-            prompt.image_prompt,
-            scale=prompt.image_prompt_strength,
-            model_type="plus",
-        )
    if init_latent is not None:
        noise_step = noise_step if noise_step is not None else first_step

@@ -397,6 +397,8 @@ IP_ADAPTER_WEIGHT_LOCATIONS = {
        "normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl_vit-h.safetensors",
    },
 }
+SD21_UNCLIP_WEIGHTS_URL = "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/e99f66a92bdcd1b0fb0d4b6a9b81b3b37d8bea44/image_encoder/model.fp16.safetensors"
+
 SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]

@@ -145,7 +145,7 @@ class DPT(BaseModel):
 class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
-        features = kwargs.pop("features", 256)
+        features = kwargs.get("features", 256)
        head_features_1 = kwargs.pop("head_features_1", features)
        head_features_2 = kwargs.pop("head_features_2", 32)
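
The switch from kwargs.pop to kwargs.get means "features" is read here while also remaining in kwargs for any downstream consumer of the remaining keyword arguments. The difference in isolation (a toy sketch, not code from this diff):

    kwargs = {"features": 256}
    f1 = kwargs.pop("features", 256)  # returns 256; kwargs is now {}
    kwargs = {"features": 256}
    f2 = kwargs.get("features", 256)  # returns 256; kwargs still holds "features"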

@@ -24,7 +24,10 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
 from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
    CLIPTextEncoderL,
 )
-from imaginairy.vendored.refiners.foundationals.latent_diffusion import SDXLIPAdapter
+from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
+    SD1IPAdapter,
+    SDXLIPAdapter,
+)
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
    TLatentDiffusionModel,
 )
@@ -59,7 +62,9 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
 )
 from imaginairy.weight_management.conversion import cast_weights
 from imaginairy.weight_management.translators import (
+    diffusers_ip_adapter_plus_sd15_to_refiners_translator,
    diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
+    diffusers_ip_adapter_sd15_to_refiners_translator,
    diffusers_ip_adapter_sdxl_to_refiners_translator,
    transformers_image_encoder_to_refiners_translator,
 )
@@ -114,7 +119,66 @@ class TileModeMixin(nn.Module):
                m.padding_y = (0, 0, rprt[2], rprt[3])  # type: ignore
 
 
-class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
+class SD1ImagePromptMixin(nn.Module):
+    def _get_ip_adapter(self, model_type: str):
+        valid_model_types = ["normal", "plus", "plus-face"]
+        if model_type not in valid_model_types:
+            msg = f"IP Adapter model_type must be one of {valid_model_types}"
+            raise ValueError(msg)
+
+        ip_adapter_weights_path = get_cached_url_path(
+            config.IP_ADAPTER_WEIGHT_LOCATIONS["sd15"][model_type]
+        )
+        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
+
+        if "plus" in model_type:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_plus_sd15_to_refiners_translator()
+            )
+        else:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_sd15_to_refiners_translator()
+            )
+        clip_image_weight_translator = (
+            transformers_image_encoder_to_refiners_translator()
+        )
+
+        ip_adapter = SD1IPAdapter(
+            target=self.unet,
+            weights=ip_adapter_weight_translator.load_and_translate_weights(
+                ip_adapter_weights_path
+            ),
+            fine_grained="plus" in model_type,
+        )
+        ip_adapter.clip_image_encoder.load_state_dict(
+            clip_image_weight_translator.load_and_translate_weights(
+                clip_image_weights_path
+            ),
+            assign=True,
+        )
+        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
+        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
+        return ip_adapter
+
+    def set_image_prompt(
+        self, images: list[Image.Image], scale: float, model_type: str = "normal"
+    ):
+        ip_adapter = self._get_ip_adapter(model_type)
+        ip_adapter.inject()
+        ip_adapter.set_scale(scale)
+
+        image_embeddings = []
+        for image in images:
+            image_embedding = ip_adapter.compute_clip_image_embedding(
+                ip_adapter.preprocess_image(image).to(device=self.unet.device)
+            )
+            image_embeddings.append(image_embedding)
+
+        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
+        ip_adapter.set_clip_image_embedding(clip_image_embedding)
+
+
+class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1):
    def __init__(
        self,
        unet: SD1UNet | None = None,
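
SD1ImagePromptMixin gives the SD1.5 classes a one-call way to attach an IP-Adapter: it downloads and translates the adapter and CLIP image-encoder weights, injects the adapter into the UNet, and conditions it on the average CLIP embedding of the supplied images. A rough sketch of calling it directly (model construction and weight loading elided; the reference filename is hypothetical):

    from PIL import Image

    sd = StableDiffusion_1()  # assumes default SD1.5 weights are set up elsewhere
    ref = Image.open("reference.jpg").convert("RGB")
    sd.set_image_prompt([ref], scale=0.4, model_type="plus")
    # subsequent UNet calls now also attend to the reference image embedding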
@@ -319,9 +383,7 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
        )
-        clip_image_weights_path = get_cached_url_path(
-            "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.fp16.safetensors"
-        )
+        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
        if "plus" in model_type:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
@@ -391,7 +453,9 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
        return clip_text_embedding, pooled_text_embedding  # type: ignore
 
 
-class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
+class StableDiffusion_1_Inpainting(
+    TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1_Inpainting
+):
    def compute_self_attention_guidance(
        self,
        x: Tensor,
@@ -423,7 +487,17 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
            tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
            dim=1,
        )
-        degraded_noise = self.unet(x)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
+                "clip_image_embedding"
+            ].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)
 
        return sag.scale * (noise - degraded_noise)
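
Context for the branch above: the stored clip_image_embedding holds the image embeddings for both guidance branches stacked along the batch dimension, while the degraded self-attention-guidance pass runs a single-batch UNet call, so the tensor is temporarily halved and restored afterwards. The save/halve/restore pattern in isolation (tensor shapes are illustrative, not taken from this diff):

    import torch

    ctx = {"clip_image_embedding": torch.randn(2, 4, 768)}  # two stacked guidance branches
    saved = ctx["clip_image_embedding"].clone()
    ctx["clip_image_embedding"], _ = ctx["clip_image_embedding"].chunk(2)  # keep one branch
    # ... single-batch forward pass would run here ...
    ctx["clip_image_embedding"] = saved  # restore for the next guided step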

@@ -324,41 +324,6 @@ InpaintMethod = Literal["finetune", "control"]
 class ImaginePrompt(BaseModel, protected_namespaces=()):
    """
    The ImaginePrompt class is used for configuring and generating image prompts.
-
-    Attributes:
-        prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Primary prompt for the image generation.
-        negative_prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Prompt specifying what to avoid in the image.
-        prompt_strength (float, optional): Strength of the influence of the prompt on the output.
-        init_image (LazyLoadingImage, optional): Initial image to base the generation on.
-        init_image_strength (float, optional): Strength of the influence of the initial image.
-        image_prompt (LazyLoadingImage, optional): Image to be used as part of the prompt using IP-Adapter.
-        image_prompt_strength (float, optional): Strength of the influence of the prompt_image.
-        control_inputs (List[ControlInput], optional): Additional control inputs for image generation.
-        mask_prompt (str, optional): Mask prompt for selective area generation.
-        mask_image (LazyLoadingImage, optional): Image used for masking.
-        mask_mode (MaskMode | str): Mode of masking operation.
-        mask_modify_original (bool): Flag to modify the original image with mask.
-        outpaint (str, optional): Outpainting string for extending image boundaries.
-        model_weights (str): Weights configuration for the generation model.
-        solver_type (str): Type of solver used for image generation.
-        seed (int, optional): Seed for random number generator.
-        steps (int, optional): Number of steps for the generation process.
-        size (int | str | tuple[int, int], optional): Size of the generated image.
-        upscale (bool): Flag to enable upscaling of the generated image.
-        fix_faces (bool): Flag to apply face fixing in the generation.
-        fix_faces_fidelity (float, optional): Fidelity of face fixing.
-        conditioning (str, optional): Additional conditioning string.
-        tile_mode (str): Mode of tiling for the image.
-        allow_compose_phase (bool): Flag to allow composition phase in generation.
-        is_intermediate (bool): Flag for intermediate image processing.
-        collect_progress_latents (bool): Flag to collect progress latents.
-        caption_text (str): Caption text for the image.
-        composition_strength (float, optional): Strength of the composition effect.
-        inpaint_method (InpaintMethod): Method used for inpainting.
    """
 
    model_config = ConfigDict(extra="forbid", validate_assignment=True)
    prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)  # type: ignore
