diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py
index 5267686..5f0ed66 100644
--- a/imaginairy/api/generate_refiners.py
+++ b/imaginairy/api/generate_refiners.py
@@ -276,11 +276,18 @@ def generate_single_image(
             )
             controlnets.append((controlnet, control_image_t))
 
+        if prompt.image_prompt:
+            sd.set_image_prompt(
+                prompt.image_prompt,
+                scale=prompt.image_prompt_strength,
+                model_type="plus",
+            )
         for controlnet, control_image_t in controlnets:
             controlnet.set_controlnet_condition(
                 control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
             )
             controlnet.inject()
+
         if prompt.solver_type.lower() == SolverName.DPMPP:
             sd.scheduler = DPMSolver(num_inference_steps=prompt.steps)
         elif prompt.solver_type.lower() == SolverName.DDIM:
@@ -292,6 +299,15 @@
         sd.set_inference_steps(prompt.steps, first_step=first_step)
 
         if hasattr(sd, "mask_latents") and mask_image is not None:
+            # import numpy as np
+            # init_size = init_image.size
+            # noise_image = Image.fromarray(np.random.randint(0, 255, (init_size[1], init_size[0], 3), dtype=np.uint8))
+            # masked_image = Image.composite(init_image, noise_image, mask_image)
+
+            masked_image = Image.composite(
+                init_image, mask_image.convert("RGB"), mask_image
+            )
+            result_images["masked_image"] = masked_image
             sd.set_inpainting_conditions(
                 target_image=init_image,
                 mask=ImageOps.invert(mask_image),
diff --git a/imaginairy/cli/edit.py b/imaginairy/cli/edit.py
index cdce40e..cbeae6a 100644
--- a/imaginairy/cli/edit.py
+++ b/imaginairy/cli/edit.py
@@ -52,6 +52,8 @@ def edit_cmd(
     prompt,
     negative_prompt,
     prompt_strength,
+    image_prompt,
+    image_prompt_strength,
     outdir,
     output_file_extension,
     repeats,
@@ -108,6 +110,8 @@
         prompt_strength=prompt_strength,
         init_image=image_paths,
         init_image_strength=image_strength,
+        image_prompt=image_prompt,
+        image_prompt_strength=image_prompt_strength,
         outdir=outdir,
         output_file_extension=output_file_extension,
         repeats=repeats,
diff --git a/imaginairy/cli/imagine.py b/imaginairy/cli/imagine.py
index c166366..573006b 100644
--- a/imaginairy/cli/imagine.py
+++ b/imaginairy/cli/imagine.py
@@ -83,6 +83,8 @@ def imagine_cmd(
     prompt_strength,
     init_image,
     init_image_strength,
+    image_prompt,
+    image_prompt_strength,
     outdir,
     output_file_extension,
     repeats,
@@ -191,6 +193,8 @@
         prompt_strength=prompt_strength,
         init_image=init_image,
         init_image_strength=init_image_strength,
+        image_prompt=image_prompt,
+        image_prompt_strength=image_prompt_strength,
         outdir=outdir,
         output_file_extension=output_file_extension,
         repeats=repeats,
diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py
index 220ddb4..f69257d 100644
--- a/imaginairy/cli/shared.py
+++ b/imaginairy/cli/shared.py
@@ -23,8 +23,10 @@ def imaginairy_click_context(log_level="INFO"):
         yield
     except errors_to_catch as e:
         logger.error(e)
-        # import traceback
-        # traceback.print_exc()
+        if log_level.upper() == "DEBUG":
+            import traceback
+
+            traceback.print_exc()
 
 
 def _imagine_cmd(
@@ -35,6 +37,8 @@
     prompt_strength,
     init_image,
     init_image_strength,
+    image_prompt,
+    image_prompt_strength,
     outdir,
     output_file_extension,
     repeats,
@@ -161,6 +165,14 @@
         defaults={"negative_prompt": config.DEFAULT_NEGATIVE_PROMPT},
     )
 
+    def _img(img_str):
+        if img_str.startswith("http"):
+            return LazyLoadingImage(url=img_str)
+        else:
+            return LazyLoadingImage(filepath=img_str)
+
+    image_prompt = [_img(i) for i in image_prompt] if image_prompt else None
+
     for _ in range(repeats):
         for prompt_text in prompt_texts:
             if prompt_text not in prompt_expanding_iterators:
@@ -186,6 +198,8 @@
                 prompt_strength=prompt_strength,
                 init_image=_init_image,
                 init_image_strength=init_image_strength,
+                image_prompt=image_prompt,
+                image_prompt_strength=image_prompt_strength,
                 control_inputs=control_inputs,
                 seed=seed,
                 solver_type=solver,
@@ -312,6 +326,19 @@ common_options = [
         type=float,
         help="Starting image strength. Between 0 and 1.",
     ),
+    click.option(
+        "--image-prompt",
+        metavar="PATH|URL",
+        help="Image prompt. Path or URL. Can be specified multiple times.",
+        multiple=True,
+    ),
+    click.option(
+        "--image-prompt-strength",
+        default=None,
+        show_default=False,
+        type=float,
+        help="Image prompt strength. Between 0 and 1.",
+    ),
     click.option(
         "--outdir",
         default="./outputs",
diff --git a/imaginairy/config.py b/imaginairy/config.py
index eb945db..76bc7ec 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -380,6 +380,26 @@ SOLVER_CONFIGS = [
     ),
 ]
 
+_ip_adapter_commit = "92a2d51861c754afacf8b3aaf90845254b49f219"
+IP_ADAPTER_WEIGHT_LOCATIONS = {
+    "sd15": {
+        "full-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-full-face_sd15.safetensors",
+        "plus-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus-face_sd15.safetensors",
+        "plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus_sd15.safetensors",
+        "normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15.safetensors",
+        "light": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_light.safetensors",
+        "vitg": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_vit-G.safetensors",
+    },
+    "sdxl": {
+        "plus-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors",
+        "plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors",
+        "vit-g": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl.safetensors",
+        "normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl_vit-h.safetensors",
+    },
+}
+SD21_UNCLIP_WEIGHTS_URL = "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/e99f66a92bdcd1b0fb0d4b6a9b81b3b37d8bea44/image_encoder/model.fp16.safetensors"
+
+
 SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
 
 SOLVER_LOOKUP = {}
diff --git a/imaginairy/image_prompts.py b/imaginairy/image_prompts.py
new file mode 100644
index 0000000..e69de29
diff --git a/imaginairy/modules/midas/midas/dpt_depth.py b/imaginairy/modules/midas/midas/dpt_depth.py
index c81e42a..090f025 100644
--- a/imaginairy/modules/midas/midas/dpt_depth.py
+++ b/imaginairy/modules/midas/midas/dpt_depth.py
@@ -145,7 +145,7 @@ class DPT(BaseModel):
 
 class DPTDepthModel(DPT):
     def __init__(self, path=None, non_negative=True, **kwargs):
-        features = kwargs.pop("features", 256)
+        features = kwargs.get("features", 256)
         head_features_1 = kwargs.pop("head_features_1", features)
         head_features_2 = kwargs.pop("head_features_2", 32)
 
diff --git a/imaginairy/modules/refiners_sd.py b/imaginairy/modules/refiners_sd.py
index e758073..232b787 100644
--- a/imaginairy/modules/refiners_sd.py
+++ b/imaginairy/modules/refiners_sd.py
@@ -12,7 +12,9 @@ from torch import Tensor, device as Device, dtype as DType, nn
 from torch.nn import functional as F
 import imaginairy.vendored.refiners.fluxion.layers as fl
+from imaginairy import config
 from imaginairy.schema import WeightedPrompt
+from imaginairy.utils.downloads import get_cached_url_path
 from imaginairy.utils.feather_tile import rebuild_image, tile_image
 from imaginairy.vendored.refiners.fluxion.layers.attentions import (
     ScaledDotProductAttention,
 )
@@ -22,6 +24,10 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpol
 from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
     CLIPTextEncoderL,
 )
+from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
+    SD1IPAdapter,
+    SDXLIPAdapter,
+)
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
     TLatentDiffusionModel,
 )
@@ -55,6 +61,13 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
     SDXLUNet,
 )
 from imaginairy.weight_management.conversion import cast_weights
+from imaginairy.weight_management.translators import (
+    diffusers_ip_adapter_plus_sd15_to_refiners_translator,
+    diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
+    diffusers_ip_adapter_sd15_to_refiners_translator,
+    diffusers_ip_adapter_sdxl_to_refiners_translator,
+    transformers_image_encoder_to_refiners_translator,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -106,7 +119,66 @@ class TileModeMixin(nn.Module):
                 m.padding_y = (0, 0, rprt[2], rprt[3])  # type: ignore
 
 
-class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
+class SD1ImagePromptMixin(nn.Module):
+    def _get_ip_adapter(self, model_type: str):
+        valid_model_types = ["normal", "plus", "plus-face"]
+        if model_type not in valid_model_types:
+            msg = f"IP Adapter model_type must be one of {valid_model_types}"
+            raise ValueError(msg)
+
+        ip_adapter_weights_path = get_cached_url_path(
+            config.IP_ADAPTER_WEIGHT_LOCATIONS["sd15"][model_type]
+        )
+        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
+        if "plus" in model_type:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_plus_sd15_to_refiners_translator()
+            )
+        else:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_sd15_to_refiners_translator()
+            )
+        clip_image_weight_translator = (
+            transformers_image_encoder_to_refiners_translator()
+        )
+
+        ip_adapter = SD1IPAdapter(
+            target=self.unet,
+            weights=ip_adapter_weight_translator.load_and_translate_weights(
+                ip_adapter_weights_path
+            ),
+            fine_grained="plus" in model_type,
+        )
+        ip_adapter.clip_image_encoder.load_state_dict(
+            clip_image_weight_translator.load_and_translate_weights(
+                clip_image_weights_path
+            ),
+            assign=True,
+        )
+        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
+        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
+        return ip_adapter
+
+    def set_image_prompt(
+        self, images: list[Image.Image], scale: float, model_type: str = "normal"
+    ):
+        ip_adapter = self._get_ip_adapter(model_type)
+        ip_adapter.inject()
+
+        ip_adapter.set_scale(scale)
+        image_embeddings = []
+        for image in images:
+            image_embedding = ip_adapter.compute_clip_image_embedding(
+                ip_adapter.preprocess_image(image).to(device=self.unet.device)
+            )
+            image_embeddings.append(image_embedding)
+
+        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
+
+        ip_adapter.set_clip_image_embedding(clip_image_embedding)
+
+
+class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1):
     def __init__(
         self,
         unet: SD1UNet | None = None,
@@ -184,7 +256,68 @@ class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
         return conditioning
 
 
-class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
+class SDXLImagePromptMixin(nn.Module):
+    def _get_ip_adapter(self, model_type: str):
+        valid_model_types = ["normal", "plus", "plus-face"]
+        if model_type not in valid_model_types:
+            msg = f"IP Adapter model_type must be one of {valid_model_types}"
+            raise ValueError(msg)
+
+        ip_adapter_weights_path = get_cached_url_path(
+            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
+        )
+        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
+        if "plus" in model_type:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
+            )
+        else:
+            ip_adapter_weight_translator = (
+                diffusers_ip_adapter_sdxl_to_refiners_translator()
+            )
+        clip_image_weight_translator = (
+            transformers_image_encoder_to_refiners_translator()
+        )
+
+        ip_adapter = SDXLIPAdapter(
+            target=self.unet,
+            weights=ip_adapter_weight_translator.load_and_translate_weights(
+                ip_adapter_weights_path
+            ),
+            fine_grained="plus" in model_type,
+        )
+        ip_adapter.clip_image_encoder.load_state_dict(
+            clip_image_weight_translator.load_and_translate_weights(
+                clip_image_weights_path
+            ),
+            assign=True,
+        )
+        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
+        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
+        return ip_adapter
+
+    def set_image_prompt(
+        self, images: list[Image.Image], scale: float, model_type: str = "normal"
+    ):
+        ip_adapter = self._get_ip_adapter(model_type)
+        ip_adapter.inject()
+
+        ip_adapter.set_scale(scale)
+        image_embeddings = []
+        for image in images:
+            image_embedding = ip_adapter.compute_clip_image_embedding(
+                ip_adapter.preprocess_image(image).to(device=self.unet.device)
+            )
+            image_embeddings.append(image_embedding)
+
+        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
+
+        ip_adapter.set_clip_image_embedding(clip_image_embedding)
+
+
+class StableDiffusion_XL(
+    TileModeMixin, SDXLImagePromptMixin, RefinerStableDiffusion_XL
+):
     def __init__(
         self,
         unet: SDXLUNet | None = None,
@@ -324,7 +457,9 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
         return clip_text_embedding, pooled_text_embedding  # type: ignore
 
 
-class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
+class StableDiffusion_1_Inpainting(
+    TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1_Inpainting
+):
     def compute_self_attention_guidance(
         self,
         x: Tensor,
@@ -356,7 +491,17 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
             tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
             dim=1,
         )
-        degraded_noise = self.unet(x)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
+                "clip_image_embedding"
+            ].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)
 
         return sag.scale * (noise - degraded_noise)
 
@@ -518,7 +663,17 @@ class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
             tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
             dim=1,
         )
-        degraded_noise = self.unet(x)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
+                "clip_image_embedding"
+            ].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)
 
         return sag.scale * (noise - degraded_noise)
 
diff --git a/imaginairy/schema.py b/imaginairy/schema.py
index 7065722..73fc358 100644
--- a/imaginairy/schema.py
+++ b/imaginairy/schema.py
@@ -324,39 +324,6 @@ InpaintMethod = Literal["finetune", "control"]
 
 
 class ImaginePrompt(BaseModel, protected_namespaces=()):
-    """
-    The ImaginePrompt class is used for configuring and generating image prompts.
-
-    Attributes:
-        prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Primary prompt for the image generation.
-        negative_prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Prompt specifying what to avoid in the image.
-        prompt_strength (float, optional): Strength of the influence of the prompt on the output.
-        init_image (LazyLoadingImage, optional): Initial image to base the generation on.
-        init_image_strength (float, optional): Strength of the influence of the initial image.
-        control_inputs (List[ControlInput], optional): Additional control inputs for image generation.
-        mask_prompt (str, optional): Mask prompt for selective area generation.
-        mask_image (LazyLoadingImage, optional): Image used for masking.
-        mask_mode (MaskMode | str): Mode of masking operation.
-        mask_modify_original (bool): Flag to modify the original image with mask.
-        outpaint (str, optional): Outpainting string for extending image boundaries.
-        model_weights (str): Weights configuration for the generation model.
-        solver_type (str): Type of solver used for image generation.
-        seed (int, optional): Seed for random number generator.
-        steps (int, optional): Number of steps for the generation process.
-        size (int | str | tuple[int, int], optional): Size of the generated image.
-        upscale (bool): Flag to enable upscaling of the generated image.
-        fix_faces (bool): Flag to apply face fixing in the generation.
-        fix_faces_fidelity (float, optional): Fidelity of face fixing.
-        conditioning (str, optional): Additional conditioning string.
-        tile_mode (str): Mode of tiling for the image.
-        allow_compose_phase (bool): Flag to allow composition phase in generation.
-        is_intermediate (bool): Flag for intermediate image processing.
-        collect_progress_latents (bool): Flag to collect progress latents.
-        caption_text (str): Caption text for the image.
-        composition_strength (float, optional): Strength of the composition effect.
-        inpaint_method (InpaintMethod): Method used for inpainting.
-    """
-
     model_config = ConfigDict(extra="forbid", validate_assignment=True)
 
     prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)  # type: ignore
@@ -370,6 +337,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
     init_image_strength: float | None = Field(
         ge=0, le=1, default=None, validate_default=True
     )
+    image_prompt: List[LazyLoadingImage] | None = Field(None, validate_default=True)
+    image_prompt_strength: float = Field(ge=0, le=1, default=0.0)
     control_inputs: List[ControlInput] = Field(
         default_factory=list, validate_default=True
     )
@@ -411,6 +380,8 @@
         prompt_strength: float | None = 7.5,
         init_image: LazyLoadingImage | None = None,
         init_image_strength: float | None = None,
+        image_prompt: LazyLoadingImage | List[LazyLoadingImage] | None = None,
+        image_prompt_strength: float | None = 0.35,
         control_inputs: List[ControlInput] | None = None,
         mask_prompt: str | None = None,
         mask_image: LazyLoadingImage | None = None,
@@ -434,12 +405,20 @@
         composition_strength: float | None = 0.5,
         inpaint_method: InpaintMethod = "finetune",
     ):
+        if image_prompt and not isinstance(image_prompt, list):
+            image_prompt = [image_prompt]
+
+        if not image_prompt_strength:
+            image_prompt_strength = 0.35
+
         super().__init__(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_strength=prompt_strength,
             init_image=init_image,
             init_image_strength=init_image_strength,
+            image_prompt=image_prompt,
+            image_prompt_strength=image_prompt_strength,
             control_inputs=control_inputs,
             mask_prompt=mask_prompt,
             mask_image=mask_image,
@@ -807,6 +786,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
         data = self.model_dump()
         data["init_image"] = repr(self.init_image)
         data["mask_image"] = repr(self.mask_image)
+        data["image_prompt"] = repr(self.image_prompt)
         if self.control_inputs:
             data["control_inputs"] = [repr(ci) for ci in self.control_inputs]
         return data
diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py
index dc0bc16..0b14fb4 100644
--- a/imaginairy/utils/model_manager.py
+++ b/imaginairy/utils/model_manager.py
@@ -222,6 +222,7 @@ def _get_diffusion_model_refiners(
 
     Weights location may also be shortcut name, e.g. "SD-1.5"
     """
+    global MOST_RECENTLY_LOADED_MODEL
     _get_diffusion_model_refiners.cache_clear()
     clear_gpu_cache()
 
diff --git a/imaginairy/weight_management/weight_maps/Transformers-ClipImageEncoder-SD21.weightmap.json b/imaginairy/weight_management/weight_maps/Transformers-ClipImageEncoder-SD21.weightmap.json
index 21518f9..6dd2ff9 100644
--- a/imaginairy/weight_management/weight_maps/Transformers-ClipImageEncoder-SD21.weightmap.json
+++ b/imaginairy/weight_management/weight_maps/Transformers-ClipImageEncoder-SD21.weightmap.json
@@ -261,7 +261,9 @@
     "vision_model.encoder.layers.29.mlp.fc2": "Chain.TransformerLayer_30.Residual_2.FeedForward.Linear_2",
     "vision_model.encoder.layers.30.mlp.fc2": "Chain.TransformerLayer_31.Residual_2.FeedForward.Linear_2",
     "vision_model.encoder.layers.31.mlp.fc2": "Chain.TransformerLayer_32.Residual_2.FeedForward.Linear_2",
-    "visual_projection": "Linear"
+    "visual_projection": "Linear",
+    "vision_model.embeddings.position_ids": null
+
   },
   "regex_map": {},
   "ignore_prefixes": [],