feature: IP-Adapter (#477)

todo:
- allow specifying the IP-Adapter weights/architecture


---------

Co-authored-by: jaydrennan <jsdman1313@gmail.com>
Branch: bd/uv2
Bryce Drennan committed via GitHub, 2 months ago
parent 9c48b749d8
commit 49f2c25b6b
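For context, the new parameters can be exercised from the Python API roughly as follows. This is a minimal sketch based on the parameters added in this commit; the top-level imports mirror the project README and may differ between versions.

    from imaginairy import ImaginePrompt, LazyLoadingImage
    from imaginairy.api import imagine_image_files

    prompt = ImaginePrompt(
        "portrait of a woman in a red dress",
        image_prompt=LazyLoadingImage(filepath="reference-face.jpg"),
        image_prompt_strength=0.35,  # also the schema default when left unset
    )
    imagine_image_files([prompt], outdir="./outputs")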

@@ -276,11 +276,18 @@ def generate_single_image(
)
controlnets.append((controlnet, control_image_t))
if prompt.image_prompt:
sd.set_image_prompt(
prompt.image_prompt,
scale=prompt.image_prompt_strength,
model_type="plus",
)
for controlnet, control_image_t in controlnets:
controlnet.set_controlnet_condition(
control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
)
controlnet.inject()
if prompt.solver_type.lower() == SolverName.DPMPP:
sd.scheduler = DPMSolver(num_inference_steps=prompt.steps)
elif prompt.solver_type.lower() == SolverName.DDIM:
@@ -292,6 +299,15 @@ def generate_single_image(
sd.set_inference_steps(prompt.steps, first_step=first_step)
if hasattr(sd, "mask_latents") and mask_image is not None:
# import numpy as np
# init_size = init_image.size
# noise_image = Image.fromarray(np.random.randint(0, 255, (init_size[1], init_size[0], 3), dtype=np.uint8))
# masked_image = Image.composite(init_image, noise_image, mask_image)
masked_image = Image.composite(
init_image, mask_image.convert("RGB"), mask_image
)
result_images["masked_image"] = masked_image
sd.set_inpainting_conditions(
target_image=init_image,
mask=ImageOps.invert(mask_image),
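A note on the PIL semantics behind the `Image.composite` call above: `Image.composite(image1, image2, mask)` takes pixels from `image1` where the mask is 255 and from `image2` where it is 0, so the result above keeps `init_image` where the mask is white and goes black (the mask rendered as RGB) elsewhere. A self-contained illustration:

    from PIL import Image

    red = Image.new("RGB", (8, 8), "red")
    blue = Image.new("RGB", (8, 8), "blue")
    mask = Image.new("L", (8, 8), 0)
    mask.paste(255, (0, 0, 4, 8))  # left half white, right half black
    out = Image.composite(red, blue, mask)
    assert out.getpixel((0, 0)) == (255, 0, 0)  # white mask region -> red
    assert out.getpixel((7, 0)) == (0, 0, 255)  # black mask region -> blue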

@@ -52,6 +52,8 @@ def edit_cmd(
prompt,
negative_prompt,
prompt_strength,
image_prompt,
image_prompt_strength,
outdir,
output_file_extension,
repeats,
@@ -108,6 +110,8 @@ def edit_cmd(
prompt_strength=prompt_strength,
init_image=image_paths,
init_image_strength=image_strength,
image_prompt=image_prompt,
image_prompt_strength=image_prompt_strength,
outdir=outdir,
output_file_extension=output_file_extension,
repeats=repeats,

@@ -83,6 +83,8 @@ def imagine_cmd(
prompt_strength,
init_image,
init_image_strength,
image_prompt,
image_prompt_strength,
outdir,
output_file_extension,
repeats,
@@ -191,6 +193,8 @@ def imagine_cmd(
prompt_strength=prompt_strength,
init_image=init_image,
init_image_strength=init_image_strength,
image_prompt=image_prompt,
image_prompt_strength=image_prompt_strength,
outdir=outdir,
output_file_extension=output_file_extension,
repeats=repeats,

@@ -23,8 +23,10 @@ def imaginairy_click_context(log_level="INFO"):
yield
except errors_to_catch as e:
logger.error(e)
-# import traceback
-# traceback.print_exc()
if log_level.upper() == "DEBUG":
import traceback
traceback.print_exc()
def _imagine_cmd(
@@ -35,6 +37,8 @@ def _imagine_cmd(
prompt_strength,
init_image,
init_image_strength,
image_prompt,
image_prompt_strength,
outdir,
output_file_extension,
repeats,
@@ -161,6 +165,14 @@ def _imagine_cmd(
defaults={"negative_prompt": config.DEFAULT_NEGATIVE_PROMPT},
)
def _img(img_str):
if img_str.startswith("http"):
return LazyLoadingImage(url=img_str)
else:
return LazyLoadingImage(filepath=img_str)
image_prompt = [_img(i) for i in image_prompt] if image_prompt else None
for _ in range(repeats):
for prompt_text in prompt_texts:
if prompt_text not in prompt_expanding_iterators:
@@ -186,6 +198,8 @@ def _imagine_cmd(
prompt_strength=prompt_strength,
init_image=_init_image,
init_image_strength=init_image_strength,
image_prompt=image_prompt,
image_prompt_strength=image_prompt_strength,
control_inputs=control_inputs,
seed=seed,
solver_type=solver,
@@ -312,6 +326,19 @@ common_options = [
type=float,
help="Starting image strength. Between 0 and 1.",
),
click.option(
"--image-prompt",
metavar="PATH|URL",
help="Starting image.",
multiple=True,
),
click.option(
"--image-prompt-strength",
default=None,
show_default=False,
type=float,
help="Starting image strength. Between 0 and 1.",
),
click.option(
"--outdir",
default="./outputs",

@@ -380,6 +380,26 @@ SOLVER_CONFIGS = [
),
]
_ip_adapter_commit = "92a2d51861c754afacf8b3aaf90845254b49f219"
IP_ADAPTER_WEIGHT_LOCATIONS = {
"sd15": {
"full-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-full-face_sd15.safetensors",
"plus-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus-face_sd15.safetensors",
"plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus_sd15.safetensors",
"normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15.safetensors",
"light": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_light.safetensors",
"vitg": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_vit-G.safetensors",
},
"sdxl": {
"plus-face": f"https://huggingface.co/h94/IP-Adapter/blob/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors",
"plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors",
"vit-g": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl.safetensors",
"normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl_vit-h.safetensors",
},
}
SD21_UNCLIP_WEIGHTS_URL = "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/e99f66a92bdcd1b0fb0d4b6a9b81b3b37d8bea44/image_encoder/model.fp16.safetensors"
SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
SOLVER_LOOKUP = {}
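A sketch of how these lookup tables are consumed, mirroring the `_get_ip_adapter` methods later in this diff:

    from imaginairy import config
    from imaginairy.utils.downloads import get_cached_url_path

    # downloads the SD 1.5 "plus" IP-Adapter weights on first use,
    # then reuses the cached local copy
    weights_path = get_cached_url_path(
        config.IP_ADAPTER_WEIGHT_LOCATIONS["sd15"]["plus"]
    )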

@@ -145,7 +145,7 @@ class DPT(BaseModel):
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
-features = kwargs.pop("features", 256)
features = kwargs.get("features", 256)
head_features_1 = kwargs.pop("head_features_1", features)
head_features_2 = kwargs.pop("head_features_2", 32)
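The `pop` → `get` change above matters because `dict.pop` removes the key, so `features` would no longer be forwarded through `**kwargs` (presumably to the parent constructor); `dict.get` reads it without consuming it:

    kwargs = {"features": 256}
    assert kwargs.pop("features", 64) == 256
    assert "features" not in kwargs  # pop consumed the key

    kwargs = {"features": 256}
    assert kwargs.get("features", 64) == 256
    assert kwargs == {"features": 256}  # get leaves kwargs intact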

@@ -12,7 +12,9 @@ from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy import config
from imaginairy.schema import WeightedPrompt
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
ScaledDotProductAttention,
@@ -22,6 +24,10 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpol
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
SD1IPAdapter,
SDXLIPAdapter,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
TLatentDiffusionModel,
)
@@ -55,6 +61,13 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
SDXLUNet,
)
from imaginairy.weight_management.conversion import cast_weights
from imaginairy.weight_management.translators import (
diffusers_ip_adapter_plus_sd15_to_refiners_translator,
diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
diffusers_ip_adapter_sd15_to_refiners_translator,
diffusers_ip_adapter_sdxl_to_refiners_translator,
transformers_image_encoder_to_refiners_translator,
)
logger = logging.getLogger(__name__)
@@ -106,7 +119,66 @@ class TileModeMixin(nn.Module):
m.padding_y = (0, 0, rprt[2], rprt[3]) # type: ignore
-class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
class SD1ImagePromptMixin(nn.Module):
def _get_ip_adapter(self, model_type: str):
valid_model_types = ["normal", "plus", "plus-face"]
if model_type not in valid_model_types:
msg = f"IP Adapter model_type must be one of {valid_model_types}"
raise ValueError(msg)
ip_adapter_weights_path = get_cached_url_path(
config.IP_ADAPTER_WEIGHT_LOCATIONS["sd15"][model_type]
)
clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
if "plus" in model_type:
ip_adapter_weight_translator = (
diffusers_ip_adapter_plus_sd15_to_refiners_translator()
)
else:
ip_adapter_weight_translator = (
diffusers_ip_adapter_sd15_to_refiners_translator()
)
clip_image_weight_translator = (
transformers_image_encoder_to_refiners_translator()
)
ip_adapter = SD1IPAdapter(
target=self.unet,
weights=ip_adapter_weight_translator.load_and_translate_weights(
ip_adapter_weights_path
),
fine_grained="plus" in model_type,
)
ip_adapter.clip_image_encoder.load_state_dict(
clip_image_weight_translator.load_and_translate_weights(
clip_image_weights_path
),
assign=True,
)
ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
return ip_adapter
def set_image_prompt(
self, images: list[Image.Image], scale: float, model_type: str = "normal"
):
ip_adapter = self._get_ip_adapter(model_type)
ip_adapter.inject()
ip_adapter.set_scale(scale)
image_embeddings = []
for image in images:
image_embedding = ip_adapter.compute_clip_image_embedding(
ip_adapter.preprocess_image(image).to(device=self.unet.device)
)
image_embeddings.append(image_embedding)
clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
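When multiple image prompts are supplied, `set_image_prompt` conditions on the arithmetic mean of their CLIP image embeddings. A toy illustration of the `sum(...) / len(...)` pattern (tensor shapes here are illustrative only):

    import torch

    embeddings = [torch.randn(1, 4, 768), torch.randn(1, 4, 768)]
    mean = sum(embeddings) / len(embeddings)
    assert torch.allclose(mean, (embeddings[0] + embeddings[1]) / 2)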
class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1):
def __init__(
self,
unet: SD1UNet | None = None,
@@ -184,7 +256,68 @@ class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
return conditioning
-class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
class SDXLImagePromptMixin(nn.Module):
def _get_ip_adapter(self, model_type: str):
valid_model_types = ["normal", "plus", "plus-face"]
if model_type not in valid_model_types:
msg = f"IP Adapter model_type must be one of {valid_model_types}"
raise ValueError(msg)
ip_adapter_weights_path = get_cached_url_path(
config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
)
clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)
if "plus" in model_type:
ip_adapter_weight_translator = (
diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
)
else:
ip_adapter_weight_translator = (
diffusers_ip_adapter_sdxl_to_refiners_translator()
)
clip_image_weight_translator = (
transformers_image_encoder_to_refiners_translator()
)
ip_adapter = SDXLIPAdapter(
target=self.unet,
weights=ip_adapter_weight_translator.load_and_translate_weights(
ip_adapter_weights_path
),
fine_grained="plus" in model_type,
)
ip_adapter.clip_image_encoder.load_state_dict(
clip_image_weight_translator.load_and_translate_weights(
clip_image_weights_path
),
assign=True,
)
ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
return ip_adapter
def set_image_prompt(
self, images: list[Image.Image], scale: float, model_type: str = "normal"
):
ip_adapter = self._get_ip_adapter(model_type)
ip_adapter.inject()
ip_adapter.set_scale(scale)
image_embeddings = []
for image in images:
image_embedding = ip_adapter.compute_clip_image_embedding(
ip_adapter.preprocess_image(image).to(device=self.unet.device)
)
image_embeddings.append(image_embedding)
clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
class StableDiffusion_XL(
TileModeMixin, SDXLImagePromptMixin, RefinerStableDiffusion_XL
):
def __init__(
self,
unet: SDXLUNet | None = None,
@@ -324,7 +457,9 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
return clip_text_embedding, pooled_text_embedding # type: ignore
-class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
class StableDiffusion_1_Inpainting(
TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1_Inpainting
):
def compute_self_attention_guidance(
self,
x: Tensor,
@@ -356,7 +491,17 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
-degraded_noise = self.unet(x)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky; it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
"clip_image_embedding"
].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)
@@ -518,7 +663,17 @@ class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
-degraded_noise = self.unet(x)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky; it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
"clip_image_embedding"
].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)
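In both self-attention-guidance paths above, the cached `clip_image_embedding` evidently holds the doubled batch used for classifier-free guidance, while the degraded pass runs the UNet on a single batch; the context is therefore temporarily narrowed to the first half and restored afterwards. `Tensor.chunk(2)` splits along dim 0 by default:

    import torch

    doubled = torch.randn(2, 4, 768)  # e.g. two conditioning halves stacked
    first_half, _ = doubled.chunk(2)  # split along dim 0
    assert first_half.shape == (1, 4, 768)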

@@ -324,39 +324,6 @@ InpaintMethod = Literal["finetune", "control"]
class ImaginePrompt(BaseModel, protected_namespaces=()):
"""
The ImaginePrompt class is used for configuring and generating image prompts.
-Attributes:
-prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Primary prompt for the image generation.
-negative_prompt (str | WeightedPrompt | list[WeightedPrompt] | list[str] | None, optional): Prompt specifying what to avoid in the image.
-prompt_strength (float, optional): Strength of the influence of the prompt on the output.
-init_image (LazyLoadingImage, optional): Initial image to base the generation on.
-init_image_strength (float, optional): Strength of the influence of the initial image.
-control_inputs (List[ControlInput], optional): Additional control inputs for image generation.
-mask_prompt (str, optional): Mask prompt for selective area generation.
-mask_image (LazyLoadingImage, optional): Image used for masking.
-mask_mode (MaskMode | str): Mode of masking operation.
-mask_modify_original (bool): Flag to modify the original image with mask.
-outpaint (str, optional): Outpainting string for extending image boundaries.
-model_weights (str): Weights configuration for the generation model.
-solver_type (str): Type of solver used for image generation.
-seed (int, optional): Seed for random number generator.
-steps (int, optional): Number of steps for the generation process.
-size (int | str | tuple[int, int], optional): Size of the generated image.
-upscale (bool): Flag to enable upscaling of the generated image.
-fix_faces (bool): Flag to apply face fixing in the generation.
-fix_faces_fidelity (float, optional): Fidelity of face fixing.
-conditioning (str, optional): Additional conditioning string.
-tile_mode (str): Mode of tiling for the image.
-allow_compose_phase (bool): Flag to allow composition phase in generation.
-is_intermediate (bool): Flag for intermediate image processing.
-collect_progress_latents (bool): Flag to collect progress latents.
-caption_text (str): Caption text for the image.
-composition_strength (float, optional): Strength of the composition effect.
-inpaint_method (InpaintMethod): Method used for inpainting.
"""
model_config = ConfigDict(extra="forbid", validate_assignment=True)
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) # type: ignore
@@ -370,6 +337,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
init_image_strength: float | None = Field(
ge=0, le=1, default=None, validate_default=True
)
image_prompt: List[LazyLoadingImage] | None = Field(None, validate_default=True)
image_prompt_strength: float = Field(ge=0, le=1, default=0.0)
control_inputs: List[ControlInput] = Field(
default_factory=list, validate_default=True
)
@@ -411,6 +380,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
prompt_strength: float | None = 7.5,
init_image: LazyLoadingImage | None = None,
init_image_strength: float | None = None,
image_prompt: LazyLoadingImage | List[LazyLoadingImage] | None = None,
image_prompt_strength: float | None = 0.35,
control_inputs: List[ControlInput] | None = None,
mask_prompt: str | None = None,
mask_image: LazyLoadingImage | None = None,
@@ -434,12 +405,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
composition_strength: float | None = 0.5,
inpaint_method: InpaintMethod = "finetune",
):
if image_prompt and not isinstance(image_prompt, list):
image_prompt = [image_prompt]
if not image_prompt_strength:
image_prompt_strength = 0.35
super().__init__(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_strength=prompt_strength,
init_image=init_image,
init_image_strength=init_image_strength,
image_prompt=image_prompt,
image_prompt_strength=image_prompt_strength,
control_inputs=control_inputs,
mask_prompt=mask_prompt,
mask_image=mask_image,
@@ -807,6 +786,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
data = self.model_dump()
data["init_image"] = repr(self.init_image)
data["mask_image"] = repr(self.mask_image)
data["image_prompt"] = repr(self.image_prompt)
if self.control_inputs:
data["control_inputs"] = [repr(ci) for ci in self.control_inputs]
return data
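One subtlety in the constructor change above: `if not image_prompt_strength` treats an explicit `0.0` the same as `None`, so a caller cannot actually request zero strength; it is silently coerced to the 0.35 default. A distilled sketch of that check (the helper name here is hypothetical):

    def resolve_strength(image_prompt_strength):
        if not image_prompt_strength:  # True for None and for 0.0
            image_prompt_strength = 0.35
        return image_prompt_strength

    assert resolve_strength(None) == 0.35
    assert resolve_strength(0.0) == 0.35  # explicit zero is overridden too
    assert resolve_strength(0.2) == 0.2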

@@ -222,6 +222,7 @@ def _get_diffusion_model_refiners(
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
global MOST_RECENTLY_LOADED_MODEL
_get_diffusion_model_refiners.cache_clear()
clear_gpu_cache()
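`cache_clear()` is the standard `functools.lru_cache` API, which implies `_get_diffusion_model_refiners` is memoized; clearing the cache before loading lets previously loaded models be garbage-collected instead of staying pinned by the cache. A generic sketch:

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def load_model(name: str):
        ...  # expensive model construction

    load_model("SD-1.5")
    load_model.cache_clear()  # drop the cached instance so memory can be reclaimed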

@@ -261,7 +261,9 @@
"vision_model.encoder.layers.29.mlp.fc2": "Chain.TransformerLayer_30.Residual_2.FeedForward.Linear_2",
"vision_model.encoder.layers.30.mlp.fc2": "Chain.TransformerLayer_31.Residual_2.FeedForward.Linear_2",
"vision_model.encoder.layers.31.mlp.fc2": "Chain.TransformerLayer_32.Residual_2.FeedForward.Linear_2",
"visual_projection": "Linear"
"visual_projection": "Linear",
"vision_model.embeddings.position_ids": null
},
"regex_map": {},
"ignore_prefixes": [],
