mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-17 09:25:47 +00:00
73 lines
1.9 KiB
Python
73 lines
1.9 KiB
Python
"""
|
|
Wrapper for instruct pix2pix model.
|
|
|
|
modified from https://github.com/timothybrooks/instruct-pix2pix/blob/main/edit_cli.py
|
|
"""
|
|
import torch
|
|
from einops import einops
|
|
from torch import nn
|
|
|
|
from imaginairy.samplers.base import mask_blend
|
|
|
|
|
|
class CFGEditingDenoiser(nn.Module):
|
|
def __init__(self, model):
|
|
super().__init__()
|
|
self.inner_model = model
|
|
|
|
def forward(
|
|
self,
|
|
z,
|
|
sigma,
|
|
cond,
|
|
uncond,
|
|
cond_scale,
|
|
image_cfg_scale=1.5,
|
|
mask=None,
|
|
mask_noise=None,
|
|
orig_latent=None,
|
|
):
|
|
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
|
|
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
|
|
cfg_cond = {
|
|
"c_crossattn": [
|
|
torch.cat(
|
|
[
|
|
cond["c_crossattn"][0],
|
|
uncond["c_crossattn"][0],
|
|
uncond["c_crossattn"][0],
|
|
]
|
|
)
|
|
],
|
|
"c_concat": [
|
|
torch.cat(
|
|
[cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]]
|
|
)
|
|
],
|
|
}
|
|
|
|
if mask is not None:
|
|
assert orig_latent is not None
|
|
t = self.inner_model.sigma_to_t(sigma, quantize=True)
|
|
big_sigma = max(sigma, 1)
|
|
cfg_z = mask_blend(
|
|
noisy_latent=cfg_z,
|
|
orig_latent=orig_latent * big_sigma,
|
|
mask=mask,
|
|
mask_noise=mask_noise * big_sigma,
|
|
ts=t,
|
|
model=self.inner_model.inner_model,
|
|
)
|
|
|
|
out_cond, out_img_cond, out_uncond = self.inner_model(
|
|
cfg_z, cfg_sigma, cond=cfg_cond
|
|
).chunk(3)
|
|
|
|
result = (
|
|
out_uncond
|
|
+ cond_scale * (out_cond - out_img_cond)
|
|
+ image_cfg_scale * (out_img_cond - out_uncond)
|
|
)
|
|
|
|
return result
|