imaginAIry/imaginairy/samplers/editing.py

"""
Wrapper for instruct pix2pix model.
modified from https://github.com/timothybrooks/instruct-pix2pix/blob/main/edit_cli.py
"""
import torch
from einops import einops
from torch import nn

from imaginairy.samplers.base import mask_blend


class CFGEditingDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(
        self,
        z,
        sigma,
        cond,
        uncond,
        cond_scale,
        image_cfg_scale=1.5,
        mask=None,
        mask_noise=None,
        orig_latent=None,
    ):
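        # Run all three guidance branches (text+image conditioned, image-only
        # conditioned, and fully unconditional) in one batched forward pass by
        # repeating the latent and sigma three times.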
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [
                torch.cat(
                    [
                        cond["c_crossattn"][0],
                        uncond["c_crossattn"][0],
                        uncond["c_crossattn"][0],
                    ]
                )
            ],
            "c_concat": [
                torch.cat(
                    [cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]]
                )
            ],
        }
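        # Optional inpainting path: blend the original image latent (scaled to
        # the current noise level) back into the latent according to the mask.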
        if mask is not None:
            assert orig_latent is not None
            t = self.inner_model.sigma_to_t(sigma, quantize=True)
            big_sigma = max(sigma, 1)
            cfg_z = mask_blend(
                noisy_latent=cfg_z,
                orig_latent=orig_latent * big_sigma,
                mask=mask,
                mask_noise=mask_noise * big_sigma,
                ts=t,
                model=self.inner_model.inner_model,
            )
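
        # Instruct-pix2pix classifier-free guidance: split the batched output
        # back into its three branches and combine them with separate text
        # (cond_scale) and image (image_cfg_scale) guidance weights.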
        out_cond, out_img_cond, out_uncond = self.inner_model(
            cfg_z, cfg_sigma, cond=cfg_cond
        ).chunk(3)
        result = (
            out_uncond
            + cond_scale * (out_cond - out_img_cond)
            + image_cfg_scale * (out_img_cond - out_uncond)
        )
        return result
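

# --- Usage sketch (illustration only; not part of the original module) ------
# In practice the wrapped model is a k-diffusion CompVisDenoiser around the
# instruct-pix2pix UNet. The stub below only mimics the call interface this
# class relies on (callable with (z, sigma, cond=...)), so the example can run
# self-contained; the stub class, tensor shapes, and scales are assumptions.
if __name__ == "__main__":

    class _StubDenoiser(nn.Module):
        def forward(self, z, sigma, cond=None):
            return z  # identity "denoiser", just to exercise the wrapper

    denoiser = CFGEditingDenoiser(_StubDenoiser())
    z = torch.randn(1, 4, 64, 64)  # noisy image latent
    sigma = torch.ones(1)  # current noise level
    text_emb = torch.randn(1, 77, 768)  # prompt embedding (CLIP-sized)
    image_latent = torch.randn(1, 4, 64, 64)  # encoded input image
    cond = {"c_crossattn": [text_emb], "c_concat": [image_latent]}
    uncond = {
        "c_crossattn": [torch.zeros_like(text_emb)],
        "c_concat": [torch.zeros_like(image_latent)],
    }
    out = denoiser(z, sigma, cond=cond, uncond=uncond, cond_scale=7.5)
    assert out.shape == z.shape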