feature: IPAdapter for sdxl

todo
- allow specifying the IP-Adapter weights/architecture
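
Example of what this enables (a sketch; field names are as added in this diff, import locations assumed):

    from imaginairy.schema import ImaginePrompt, LazyLoadingImage

    prompt = ImaginePrompt(
        "a photo in the style of the reference image",
        image_prompt=LazyLoadingImage(url="https://example.com/ref.jpg"),
        image_prompt_strength=0.35,  # also the default when left unset
    )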
pull/477/head
Bryce 4 months ago
parent 16f58e1f8e
commit 2e5e20c0d5

@@ -281,6 +281,7 @@ def generate_single_image(
            control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
        )
        controlnet.inject()
    if prompt.solver_type.lower() == SolverName.DPMPP:
        sd.scheduler = DPMSolver(num_inference_steps=prompt.steps)
    elif prompt.solver_type.lower() == SolverName.DDIM:
@@ -303,6 +304,12 @@ def generate_single_image(
        sd.mask_latents = sd.mask_latents.to(
            dtype=sd.unet.dtype, device=sd.unet.device
        )
    if prompt.image_prompt:
        sd.set_image_prompt(
            prompt.image_prompt,
            scale=prompt.image_prompt_strength,
            model_type="plus",
        )
    if init_latent is not None:
        noise_step = noise_step if noise_step is not None else first_step

@@ -23,8 +23,10 @@ def imaginairy_click_context(log_level="INFO"):
        yield
    except errors_to_catch as e:
        logger.error(e)
        # import traceback
        # traceback.print_exc()
        if log_level.upper() == "DEBUG":
            import traceback

            traceback.print_exc()
def _imagine_cmd(
@@ -163,6 +165,14 @@ def _imagine_cmd(
        defaults={"negative_prompt": config.DEFAULT_NEGATIVE_PROMPT},
    )

    def _img(img_str):
        if img_str.startswith("http"):
            return LazyLoadingImage(url=img_str)
        else:
            return LazyLoadingImage(filepath=img_str)

    image_prompt = [_img(i) for i in image_prompt] if image_prompt else None

    for _ in range(repeats):
        for prompt_text in prompt_texts:
            if prompt_text not in prompt_expanding_iterators:
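
The click option wiring is outside this hunk; assuming it is exposed as `--image-prompt`, both URLs and file paths are accepted thanks to `_img`, e.g. `imagine "portrait" --image-prompt https://example.com/face.jpg`.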

@@ -380,6 +380,24 @@ SOLVER_CONFIGS = [
    ),
]

_ip_adapter_commit = "92a2d51861c754afacf8b3aaf90845254b49f219"
IP_ADAPTER_WEIGHT_LOCATIONS = {
    "sd15": {
        "full-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-full-face_sd15.safetensors",
        "plus-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus-face_sd15.safetensors",
        "plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter-plus_sd15.safetensors",
        "normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15.safetensors",
        "light": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_light.safetensors",
        "vitg": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/models/ip-adapter_sd15_vit-G.safetensors",
    },
    "sdxl": {
        "plus-face": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors",
        "plus": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors",
        "vit-g": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl.safetensors",
        "normal": f"https://huggingface.co/h94/IP-Adapter/resolve/{_ip_adapter_commit}/sdxl_models/ip-adapter_sdxl_vit-h.safetensors",
    },
}
SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
SOLVER_LOOKUP = {}
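
These URLs feed the existing download cache; resolving one to a local file is a one-liner (sketch using helpers already present in this diff):

    from imaginairy import config
    from imaginairy.utils.downloads import get_cached_url_path

    # downloads on first use, returns the cached local path afterwards
    path = get_cached_url_path(config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"]["plus"])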

@@ -12,7 +12,9 @@ from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy import config
from imaginairy.schema import WeightedPrompt
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
    ScaledDotProductAttention,
@@ -22,6 +24,7 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpol
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
    CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion import SDXLIPAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
    TLatentDiffusionModel,
)
@@ -55,6 +58,11 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
    SDXLUNet,
)
from imaginairy.weight_management.conversion import cast_weights
from imaginairy.weight_management.translators import (
    diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
    diffusers_ip_adapter_sdxl_to_refiners_translator,
    transformers_image_encoder_to_refiners_translator,
)

logger = logging.getLogger(__name__)
@@ -302,6 +310,65 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
            ),
        }

    def _get_ip_adapter(self, model_type: str):
        valid_model_types = ["normal", "plus", "plus-face"]
        if model_type not in valid_model_types:
            msg = f"IP Adapter model_type must be one of {valid_model_types}"
            raise ValueError(msg)
        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
        )
        clip_image_weights_path = get_cached_url_path(
            "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.fp16.safetensors"
        )
        if "plus" in model_type:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
            )
        else:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_sdxl_to_refiners_translator()
            )
        clip_image_weight_translator = (
            transformers_image_encoder_to_refiners_translator()
        )

        ip_adapter = SDXLIPAdapter(
            target=self.unet,
            weights=ip_adapter_weight_translator.load_and_translate_weights(
                ip_adapter_weights_path
            ),
            fine_grained="plus" in model_type,
        )
        ip_adapter.clip_image_encoder.load_state_dict(
            clip_image_weight_translator.load_and_translate_weights(
                clip_image_weights_path
            ),
            assign=True,
        )
        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
        ip_adapter.clip_image_encoder.to(
            device=self.unet.device, dtype=self.unet.dtype
        )
        return ip_adapter

    def set_image_prompt(
        self, images: list[Image.Image], scale: float, model_type: str = "normal"
    ):
        ip_adapter = self._get_ip_adapter(model_type)
        ip_adapter.inject()
        ip_adapter.set_scale(scale)

        image_embeddings = []
        for image in images:
            image_embedding = ip_adapter.compute_clip_image_embedding(
                ip_adapter.preprocess_image(image).to(device=self.unet.device)
            )
            image_embeddings.append(image_embedding)
        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)
        ip_adapter.set_clip_image_embedding(clip_image_embedding)

    def prompts_to_embeddings(
        self, prompts: List[WeightedPrompt]
    ) -> tuple[Tensor, Tensor]:
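
With the two methods above, conditioning a generation on reference images is a single call (sketch; `sd` is a StableDiffusion_XL instance, `ref_image` a PIL image):

    # injects the adapter and conditions on the averaged CLIP image embedding
    sd.set_image_prompt([ref_image], scale=0.35, model_type="plus")

Note that multiple reference images are combined by simple averaging of their CLIP embeddings, so each image contributes equally.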

@@ -372,8 +372,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
    init_image_strength: float | None = Field(
        ge=0, le=1, default=None, validate_default=True
    )
    image_prompt: LazyLoadingImage | None = Field(None, validate_default=True)
    image_prompt_strength: float | None = Field(ge=0, le=1, default=0.0)
    image_prompt: List[LazyLoadingImage] | None = Field(None, validate_default=True)
    image_prompt_strength: float = Field(ge=0, le=1, default=0.0)
    control_inputs: List[ControlInput] = Field(
        default_factory=list, validate_default=True
    )
@@ -415,8 +415,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
        prompt_strength: float | None = 7.5,
        init_image: LazyLoadingImage | None = None,
        init_image_strength: float | None = None,
        image_prompt: LazyLoadingImage | None = None,
        image_prompt_strength: float | None = None,
        image_prompt: LazyLoadingImage | List[LazyLoadingImage] | None = None,
        image_prompt_strength: float | None = 0.35,
        control_inputs: List[ControlInput] | None = None,
        mask_prompt: str | None = None,
        mask_image: LazyLoadingImage | None = None,
@@ -440,6 +440,12 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
        composition_strength: float | None = 0.5,
        inpaint_method: InpaintMethod = "finetune",
    ):
        if image_prompt and not isinstance(image_prompt, list):
            image_prompt = [image_prompt]
        # check against None so an explicit strength of 0.0 is preserved
        if image_prompt_strength is None:
            image_prompt_strength = 0.35
        super().__init__(
            prompt=prompt,
            negative_prompt=negative_prompt,
@@ -815,6 +821,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
        data = self.model_dump()
        data["init_image"] = repr(self.init_image)
        data["mask_image"] = repr(self.mask_image)
        data["image_prompt"] = repr(self.image_prompt)
        if self.control_inputs:
            data["control_inputs"] = [repr(ci) for ci in self.control_inputs]
        return data
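
Net effect of the schema changes (sketch): a scalar image prompt is wrapped into a list, and an omitted strength falls back to 0.35:

    img = LazyLoadingImage(filepath="ref.jpg")
    p = ImaginePrompt("a dog", image_prompt=img)
    # p.image_prompt is now [img]; p.image_prompt_strength is 0.35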

@@ -204,42 +204,6 @@ def get_diffusion_model_refiners(
    # ensures a "fresh" copy that doesn't have additional injected parts
    sd = sd.structural_copy()

    # inject ip-adapter (img to img prompt)
    from PIL import Image

    from imaginairy.vendored.refiners.fluxion.utils import (
        load_from_safetensors,
        no_grad,
    )
    from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
        SDXLIPAdapter,
    )

    image_prompt = Image.open(
        "/imaginAIry/docs/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg"
    )
    ip_adapter = SDXLIPAdapter(
        target=sd.unet,
        weights=load_from_safetensors(
            "/imaginAIry/imaginairy/utils/ip-adapter_sdxl_vit-h.safetensors"
        ),
    )
    ip_adapter.clip_image_encoder.load_from_safetensors(
        "/imaginAIry/imaginairy/utils/clip_image.safetensors"
    )
    ip_adapter.inject()

    scale = 0.4
    ip_adapter.set_scale(scale)
    print(f"SCALE: {scale}")

    with no_grad():
        clip_image_embedding = ip_adapter.compute_clip_image_embedding(
            ip_adapter.preprocess_image(image_prompt)
        )
        ip_adapter.set_clip_image_embedding(clip_image_embedding)

    sd.set_self_attention_guidance(enable=True)

    return sd

@@ -261,7 +261,9 @@
        "vision_model.encoder.layers.29.mlp.fc2": "Chain.TransformerLayer_30.Residual_2.FeedForward.Linear_2",
        "vision_model.encoder.layers.30.mlp.fc2": "Chain.TransformerLayer_31.Residual_2.FeedForward.Linear_2",
        "vision_model.encoder.layers.31.mlp.fc2": "Chain.TransformerLayer_32.Residual_2.FeedForward.Linear_2",
        "visual_projection": "Linear"
        "visual_projection": "Linear",
        "vision_model.embeddings.position_ids": null
    },
    "regex_map": {},
    "ignore_prefixes": [],
