feature: face enhancement and upscaling!!

Bryce 2022-09-13 00:27:53 -07:00
parent 6fa776053f
commit 541ecb9701
29 changed files with 1328 additions and 100 deletions

.gitignore (vendored): 3 changed lines
View File

@ -14,4 +14,5 @@ build
dist
**/*.ckpt
**/*.egg-info
tests/test_output
tests/test_output
gfpgan/**

View File

@ -16,6 +16,7 @@ init: require_pyenv ## Setup a dev environment for local development.
@echo -e "\033[0;32m ✔️ 🐍 $(venv_name) virtualenv activated \033[0m"
pip install --upgrade pip pip-tools
pip-sync requirements-dev.txt
pip install -e . --no-deps
@echo -e "\nEnvironment setup! ✨ 🍰 ✨ 🐍 \n\nCopy this path to tell PyCharm where your virtualenv is. You may have to click the refresh button in the pycharm file explorer.\n"
@echo -e "\033[0;32m"
@pyenv which python
@ -75,6 +76,31 @@ vendor_openai_clip:
git --git-dir ./downloads/CLIP/.git rev-parse HEAD | tee ./imaginairy/vendored/clip/clip-commit-hash.txt
echo "vendored from git@github.com:openai/CLIP.git" | tee ./imaginairy/vendored/clip/readme.txt
revendorize:
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip
make vendorize REPO=git@github.com:xinntao/Real-ESRGAN.git PKG=realesrgan
vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)
cd ./downloads/$(PKG) && git pull
rm -rf ./imaginairy/vendored/$(PKG)
cp -R ./downloads/$(PKG)/$(PKG) imaginairy/vendored/
git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt
touch ./imaginairy/vendored/$(PKG)/version.py
echo "vendored from $(REPO)" | tee ./imaginairy/vendored/$(PKG)/readme.txt
vendorize_whole_repo:
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)
cd ./downloads/$(PKG) && git pull
rm -rf ./imaginairy/vendored/$(PKG)
cp -R ./downloads/$(PKG) imaginairy/vendored/
git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt
touch ./imaginairy/vendored/$(PKG)/version.py
echo "vendored from $(REPO)" | tee ./imaginairy/vendored/$(PKG)/readme.txt
help: ## Show this help message.
@## https://gist.github.com/prwhite/8168133#gistcomment-1716694

View File

@ -53,6 +53,17 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> =>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
### Face Enhancement [by CodeFormer](https://github.com/sczhou/CodeFormer)
```bash
>> imagine "a couple smiling" --steps 40 --seed 1 --fix-faces
```
<img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> => <img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
### Upscaling [by RealESRGAN](https://github.com/xinntao/Real-ESRGAN)
<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg">
<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="512">
## Features
@ -68,20 +79,20 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
## How To
```python
from imaginairy import imagine_images, imagine_image_files, ImaginePrompt, WeightedPrompt
from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt
prompts = [
ImaginePrompt("a scenic landscape", seed=1),
ImaginePrompt("a bowl of fruit"),
ImaginePrompt([
WeightedPrompt("cat", weight=1),
WeightedPrompt("dog", weight=1),
WeightedPrompt("cat", weight=1),
WeightedPrompt("dog", weight=1),
])
]
for result in imagine_images(prompts):
for result in imagine(prompts):
# do something
result.save("my_image.jpg")
# or
imagine_image_files(prompts, outdir="./my-art")
@ -109,11 +120,12 @@ imagine_image_files(prompts, outdir="./my-art")
## Todo
- performance optimizations
- https://github.com/neonsecret/stable-diffusion https://github.com/CompVis/stable-diffusion/pull/177
- ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
- ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#
- ✅ https://www.reddit.com/r/StableDiffusion/comments/xalaws/test_update_for_less_memory_usage_and_higher/
- deploy to pypi
- https://github.com/neonsecret/stable-diffusion https://github.com/CompVis/stable-diffusion/pull/177
- ✅ deploy to pypi
- add tests
- set up ci (test/lint/format)
- add docs
@ -124,11 +136,14 @@ imagine_image_files(prompts, outdir="./my-art")
- ✅ init-image at command line
- prompt expansion
- Image Generation Features
- add in all the samplers
- upscaling
- ✅ realesrgan
- ldm
- https://github.com/lowfuel/progrock-stable
- face improvements
- gfpgan - https://github.com/TencentARC/GFPGAN
- codeformer - https://github.com/sczhou/CodeFormer
- ✅ face enhancers
- gfpgan - https://github.com/TencentARC/GFPGAN
- codeformer - https://github.com/sczhou/CodeFormer
- image describe feature - https://replicate.com/methexis-inc/img2prompt
- outpainting
- inpainting
@ -148,5 +163,16 @@ imagine_image_files(prompts, outdir="./my-art")
- textual inversion
- https://www.reddit.com/r/StableDiffusion/comments/xbwb5y/how_to_run_textual_inversion_locally_train_your/
- https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb#scrollTo=50JuJUM8EG1h
- https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_textual_inversion_library_navigator.ipynb
- zooming videos? a la disco diffusion
- fix saturation at high CFG https://www.reddit.com/r/StableDiffusion/comments/xalo78/fixing_excessive_contrastsaturation_resulting/
- https://www.reddit.com/r/StableDiffusion/comments/xbrrgt/a_rundown_of_twenty_new_methodsoptions_added_to/
## Notable Stable Diffusion Implementations
- https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion
- https://github.com/lstein/stable-diffusion
- https://github.com/AUTOMATIC1111/stable-diffusion-webui
## Further Reading
- Differences between samplers
- https://www.reddit.com/r/StableDiffusion/comments/xbeyw3/can_anyone_offer_a_little_guidance_on_the/

Binary file added (not shown): 378 KiB
Binary file added (not shown): 286 KiB
Binary file added (not shown): 33 KiB
Binary file added (not shown): 220 KiB

View File

@ -2,5 +2,5 @@ import os
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
from .api import imagine_image_files, imagine_images # noqa
from .api import imagine, imagine_image_files # noqa
from .schema import ImaginePrompt, ImagineResult, WeightedPrompt # noqa

View File

@ -1,7 +1,6 @@
import logging
import os
import re
import subprocess
from contextlib import nullcontext
from functools import lru_cache
@ -10,11 +9,13 @@ import torch
import torch.nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
from PIL import Image, ImageDraw, ImageFilter
from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.modules.diffusion.ddim import DDIMSampler
from imaginairy.modules.diffusion.plms import PLMSSampler
from imaginairy.safety import is_nsfw, safety_models
@ -93,7 +94,7 @@ def imagine_image_files(
):
big_path = os.path.join(outdir, "upscaled")
os.makedirs(outdir, exist_ok=True)
os.makedirs(big_path, exist_ok=True)
base_count = len(os.listdir(outdir))
step_count = 0
output_file_extension = output_file_extension.lower()
@ -116,7 +117,7 @@ def imagine_image_files(
img.save(os.path.join(steps_path, filename))
img_callback = _record_steps if record_step_images else None
for result in imagine_images(
for result in imagine(
prompts,
latent_channels=latent_channels,
downsampling_factor=downsampling_factor,
@ -131,14 +132,15 @@ def imagine_image_files(
result.save(filepath)
logger.info(f" 🖼 saved to: {filepath}")
if prompt.upscale:
bigfilepath = (os.path.join(big_path, basefilename) + ".jpg",)
enlarge_realesrgan2x(filepath, bigfilepath)
logger.info(f" upscaled 🖼 saved to: {filepath}")
if result.upscaled_img:
os.makedirs(big_path, exist_ok=True)
bigfilepath = os.path.join(big_path, basefilename) + "_upscaled.jpg"
result.save_upscaled(bigfilepath)
logger.info(f" Upscaled 🖼 saved to: {bigfilepath}")
base_count += 1
def imagine_images(
def imagine(
prompts,
latent_channels=4,
downsampling_factor=8,
@ -149,15 +151,16 @@ def imagine_images(
half_mode=None,
):
model = load_model(tile_mode=tile_mode)
if not IMAGINAIRY_ALLOW_NSFW:
# needs to be loaded before we set default tensor type to half
safety_models()
# only run half-mode on cuda. run it by default
half_mode = True if half_mode is None and get_device() == "cuda" else False
half_mode = half_mode is None and get_device() == "cuda"
if half_mode:
model = model.half()
# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)
# torch.set_default_tensor_type(torch.HalfTensor)
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
_img_callback = None
@ -242,20 +245,23 @@ def imagine_images(
for x_sample in x_samples:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(x_sample.astype(np.uint8))
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
upscaled_img = None
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = Image.new("RGB", img.size, (228, 150, 150))
img = img.filter(ImageFilter.GaussianBlur(radius=10))
if prompt.fix_faces:
img = fix_faces_GFPGAN(img)
# if prompt.upscale:
# enlarge_realesrgan2x(
# filepath,
# os.path.join(big_path, basefilename) + ".jpg",
# )
yield ImagineResult(img=img, prompt=prompt)
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
img = enhance_faces(img, fidelity=0.2)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")
upscaled_img = upscale_image(img)
yield ImagineResult(img=img, prompt=prompt, upscaled_img=upscaled_img)
def prompt_normalized(prompt):
@ -263,14 +269,6 @@ def prompt_normalized(prompt):
DOWNLOADED_FILES_PATH = f"{LIB_PATH}/../downloads/"
ESRGAN_PATH = DOWNLOADED_FILES_PATH + "realesrgan-ncnn-vulkan/realesrgan-ncnn-vulkan"
def enlarge_realesrgan2x(src, dst):
process = subprocess.Popen(
[ESRGAN_PATH, "-i", src, "-o", dst, "-n", "realesrgan-x4plus"]
)
process.wait()
def get_sampler(sampler_type, model):
@ -279,30 +277,3 @@ def get_sampler(sampler_type, model):
return PLMSSampler(model)
elif sampler_type == "DDIM":
return DDIMSampler(model)
def gfpgan_model():
from gfpgan import GFPGANer
return GFPGANer(
model_path=DOWNLOADED_FILES_PATH
+ "GFPGAN/experiments/pretrained_models/GFPGANv1.3.pth",
upscale=1,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
device=torch.device(get_device()),
)
def fix_faces_GFPGAN(image):
image = image.convert("RGB")
cropped_faces, restored_faces, restored_img = gfpgan_model().enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
return res

View File

@ -46,6 +46,12 @@ def filter_torch_warnings():
category=UserWarning,
message=r"The operator .*?is not currently supported.*",
)
warnings.filterwarnings(
"ignore", category=UserWarning, message=r"The parameter 'pretrained' is.*"
)
warnings.filterwarnings(
"ignore", category=UserWarning, message=r"Arguments other than a weight.*"
)
def setup_env():

View File

@ -92,6 +92,12 @@ def configure_logging(level="INFO"):
type=int,
help="What seed to use for randomness. Allows reproducible image renders",
)
@click.option("--upscale", is_flag=True)
@click.option(
"--upscale-method", default="realesrgan", type=click.Choice(["realesrgan"])
)
@click.option("--fix-faces", is_flag=True)
@click.option("--fix-faces-method", default="gfpgan", type=click.Choice(["gfpgan"]))
@click.option(
"--sampler-type",
default="PLMS",
@ -128,6 +134,10 @@ def imagine_cmd(
width,
steps,
seed,
upscale,
upscale_method,
fix_faces,
fix_faces_method,
sampler_type,
ddim_eta,
log_level,
@ -142,7 +152,7 @@ def imagine_cmd(
total_image_count = len(prompt_texts) * repeats
logger.info(
f"🤖🧠 received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
f"🤖🧠 imaginAIry received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
)
if init_image and sampler_type != "DDIM":
sampler_type = "DDIM"
@ -161,8 +171,8 @@ def imagine_cmd(
steps=steps,
height=height,
width=width,
upscale=False,
fix_faces=False,
upscale=upscale,
fix_faces=fix_faces,
)
prompts.append(prompt)
@ -172,6 +182,7 @@ def imagine_cmd(
ddim_eta=ddim_eta,
record_step_images="images" in show_work,
tile_mode=tile,
output_file_extension="png",
)

View File

View File

@ -0,0 +1,86 @@
import logging
from functools import lru_cache
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
from imaginairy.utils import get_cached_url_path
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
logger = logging.getLogger(__name__)
@lru_cache()
def codeformer_model():
model = CodeFormer(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to("cpu")
url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
ckpt_path = get_cached_url_path(url)
checkpoint = torch.load(ckpt_path)["params_ema"]
model.load_state_dict(checkpoint)
model.eval()
return model
def enhance_faces(img, fidelity=0):
net = codeformer_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
face_helper = FaceRestoreHelper(
1,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
use_parse=True,
device=device,
)
face_helper.clean_all()
image = img.convert("RGB")
np_img = np.array(image, dtype=np.uint8)
# convert RGB to BGR by reversing the channel order
np_img = np_img[:, :, ::-1]
face_helper.read_image(np_img)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
)
logger.info(f" Enhancing {num_det_faces} faces")
# align and warp each face
face_helper.align_warp_face()
# face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to("cpu")
try:
with torch.no_grad():
output = net(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
logger.error(f"\tFailed inference for CodeFormer: {error}")
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = face_helper.paste_faces_to_input_image()
res = Image.fromarray(restored_img[:, :, ::-1])
return res
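A hedged usage sketch for the enhancer above; the import path is the one `api.py` uses in this commit, and the sample image path comes from the repo's test data:

```python
from PIL import Image

from imaginairy.enhancers.face_restoration_codeformer import enhance_faces

img = Image.open("tests/data/girl_with_a_pearl_earring.jpg")
fixed = enhance_faces(img, fidelity=0.2)  # api.py above calls this with fidelity=0.2
fixed.save("girl_with_a_pearl_earring_fixed.jpg")
```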

View File

@ -0,0 +1,50 @@
from functools import lru_cache
import numpy as np
import torch
from PIL import Image
from imaginairy.utils import get_cached_url_path, get_device
@lru_cache()
def face_enhance_model(model_type="codeformer"):
from gfpgan import GFPGANer
if model_type == "gfpgan":
arch = "clean"
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
elif model_type == "codeformer":
arch = "CodeFormer"
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth"
else:
raise ValueError("model_type must be one of gfpgan, codeformer")
model_path = get_cached_url_path(url)
if get_device() == "cuda":
device = "cuda"
else:
device = "cpu"
return GFPGANer(
model_path=model_path,
upscale=1,
arch=arch,
channel_multiplier=2,
bg_upsampler=None,
device=device,
)
def fix_faces_gfpgan(image, model_type):
image = image.convert("RGB")
np_img = np.array(image, dtype=np.uint8)
cropped_faces, restored_faces, restored_img = face_enhance_model(
model_type
).enhance(
np_img, has_aligned=False, only_center_face=False, paste_back=True, weight=0
)
res = Image.fromarray(restored_img)
return res
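A usage sketch for the GFPGAN-based path. The diff does not show this module's filename, so the import path below is only a guess:

```python
from PIL import Image

# hypothetical import path; the file name is not visible in this diff
from imaginairy.enhancers.face_restoration_gfpgan import fix_faces_gfpgan

img = Image.open("tests/data/girl_with_a_pearl_earring.jpg")
fixed = fix_faces_gfpgan(img, model_type="gfpgan")  # or model_type="codeformer"
fixed.save("girl_with_a_pearl_earring_gfpgan.jpg")
```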

View File

@ -0,0 +1,40 @@
from functools import lru_cache
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from imaginairy.utils import get_cached_url_path, get_device
@lru_cache()
def realesrgan_upsampler():
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
model_path = get_cached_url_path(url)
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)
if get_device() == "cuda":
device = "cuda"
else:
device = "cpu"
upsampler.device = torch.device(device)
upsampler.model.to(device)
return upsampler
def upscale_image(img):
img = img.convert("RGB")
np_img = np.array(img, dtype=np.uint8)
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)
if __name__ == "__main__":
realesrgan_upsampler()
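A short usage sketch for the upscaler above, using the import path from `api.py` in this commit:

```python
from PIL import Image

from imaginairy.enhancers.upscale_realesrgan import upscale_image

img = Image.open("tests/data/girl_with_a_pearl_earring.jpg")
upscaled = upscale_image(img)  # 4x via RealESRGAN_x4plus; weights download on first use
upscaled.save("girl_with_a_pearl_earring_4x.jpg")
```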

View File

@ -51,7 +51,7 @@ class VectorQuantizer(nn.Module):
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
logger.info(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)

View File

@ -174,7 +174,8 @@ class DDIMSampler:
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
# run on CPU for seed consistency. M1/mps runs were not consistent otherwise
img = torch.randn(shape, device="cpu").to(device)
else:
img = x_T

View File

@ -175,7 +175,8 @@ class PLMSSampler(object):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
img = torch.randn(shape, device="cpu").to(device)
else:
img = x_T

View File

@ -26,9 +26,7 @@ def pil_img_to_latent(model, img, batch_size=1, device="cuda", half=True):
return model.get_first_stage_encoding(model.encode_first_stage(init_image))
def find_noise_for_image(
model, pil_img, prompt, steps=50, cond_scale=1.0, verbose=False, half=True
):
def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
img_latent = pil_img_to_latent(
model, pil_img, batch_size=1, device="cuda", half=half
)
@ -38,13 +36,12 @@ def find_noise_for_image(
prompt,
steps=steps,
cond_scale=cond_scale,
verbose=verbose,
half=half,
)
def find_noise_for_latent(
model, img_latent, prompt, steps=50, cond_scale=1.0, verbose=False, half=True
model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
):
import k_diffusion as K
@ -59,9 +56,6 @@ def find_noise_for_latent(
dnw = K.external.CompVisDenoiser(model)
sigmas = dnw.get_sigmas(steps).flip(0)
if verbose:
print(sigmas)
with (torch.no_grad(), _autocast(get_device())):
for i in range(1, len(sigmas)):
x_in = torch.cat([x] * 2)

View File

@ -91,8 +91,9 @@ class ExifCodes:
class ImagineResult:
def __init__(self, img, prompt: ImaginePrompt):
def __init__(self, img, prompt: ImaginePrompt, upscaled_img=None):
self.img = img
self.upscaled_img = upscaled_img
self.prompt = prompt
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.torch_backend = get_device()
@ -113,7 +114,7 @@ class ImagineResult:
"prompt": self.prompt.as_dict(),
}
def save(self, save_path):
def _exif(self):
exif = Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
@ -121,4 +122,10 @@ class ImagineResult:
exif[ExifCodes.Software] = "Imaginairy / Stable Diffusion v1.4"
exif[ExifCodes.DateTime] = self.created_at.isoformat(sep=" ")[:19]
exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
self.img.save(save_path, exif=exif)
return exif
def save(self, save_path):
self.img.save(save_path, exif=self._exif())
def save_upscaled(self, save_path):
self.upscaled_img.save(save_path, exif=self._exif())

View File

@ -1,14 +1,17 @@
import importlib
import logging
import os.path
import platform
from contextlib import contextmanager
from functools import lru_cache
from typing import List, Optional
import numpy as np
import requests
import torch
from PIL import Image
from torch import Tensor
from transformers import cached_path
logger = logging.getLogger(__name__)
@ -115,3 +118,34 @@ def pillow_img_to_torch_image(image, max_height=512, max_width=512):
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0, w, h
def get_cache_dir():
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:
user_home = os.getenv("HOME", None)
if user_home:
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy", "weights")
return os.path.join(os.path.dirname(__file__), ".cached-downloads")
def get_cached_url_path(url):
try:
return cached_path(url)
except OSError:
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
os.makedirs(dest, exist_ok=True)
dest_path = os.path.join(dest, filename)
if os.path.exists(dest_path):
return dest_path
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
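A small sketch of how the new weight-caching helper is meant to be used; the URL is copied from the Real-ESRGAN loader above:

```python
from imaginairy.utils import get_cached_url_path

url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
# tries transformers' cached_path first; on failure it downloads into
# $XDG_CACHE_HOME/imaginairy/weights (or ~/.cache/imaginairy/weights) and reuses it
model_path = get_cached_url_path(url)
print(model_path)
```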

View File

View File

@ -0,0 +1,325 @@
import math
from typing import Optional
import torch
import torch.nn.functional as F
# from basicsr.archs.vqgan_arch import *
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from imaginairy.vendored.codeformer.vqgan_arch import ResBlock, VQAutoEncoder
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have similar color and illumination to
those in the degraded features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
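For reference, the function above is the standard AdaIN transform: with μ(·) and σ(·) the per-channel mean and standard deviation computed by `calc_mean_std`, it returns

    AdaIN(c, s) = σ(s) * (c - μ(c)) / σ(c) + μ(s)

where c is `content_feat` and s is `style_feat`.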
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(
self,
dim_embd=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=["32", "64", "128", "256"],
fix_modules=["quantize", "generator"],
):
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View File

@ -0,0 +1 @@
vendored from git@github.com:sczhou/CodeFormer.git c5b4593074ba6214284d6acd5f1719b6c5d739af

View File

@ -0,0 +1,514 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script
def swish(x):
return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_ema" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_ema"]
)
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n_layers, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_d" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_d"]
)
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
return self.main(x)

View File

@ -5,7 +5,11 @@
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
absl-py==1.2.0
# via tensorboard
# via
# tb-nightly
# tensorboard
addict==2.4.0
# via basicsr
aiohttp==3.8.1
# via fsspec
aiosignal==1.2.0
@ -20,6 +24,10 @@ attrs==22.1.0
# via
# aiohttp
# pytest
basicsr==1.4.2
# via
# gfpgan
# realesrgan
black==22.8.0
# via -r requirements-dev.in
cachetools==5.2.0
@ -36,33 +44,56 @@ click==8.1.3
# imaginAIry (setup.py)
coverage==6.4.4
# via -r requirements-dev.in
cycler==0.11.0
# via matplotlib
diffusers==0.3.0
# via imaginAIry (setup.py)
dill==0.3.5.1
# via pylint
einops==0.3.0
# via imaginAIry (setup.py)
facexlib==0.2.5
# via
# gfpgan
# realesrgan
filelock==3.8.0
# via
# diffusers
# huggingface-hub
# transformers
filterpy==1.4.5
# via facexlib
fonttools==4.37.1
# via matplotlib
frozenlist==1.3.1
# via
# aiohttp
# aiosignal
fsspec[http]==2022.8.2
# via pytorch-lightning
ftfy==6.1.1
# via imaginAIry (setup.py)
future==0.18.2
# via pytorch-lightning
# via
# basicsr
# pytorch-lightning
gfpgan==1.3.7
# via
# imaginAIry (setup.py)
# realesrgan
google-auth==2.11.0
# via
# google-auth-oauthlib
# tb-nightly
# tensorboard
google-auth-oauthlib==0.4.6
# via tensorboard
# via
# tb-nightly
# tensorboard
grpcio==1.48.1
# via tensorboard
# via
# tb-nightly
# tensorboard
huggingface-hub==0.9.1
# via
# diffusers
@ -72,7 +103,9 @@ idna==3.3
# requests
# yarl
imageio==2.9.0
# via imaginAIry (setup.py)
# via
# imaginAIry (setup.py)
# scikit-image
importlib-metadata==4.12.0
# via diffusers
iniconfig==1.1.1
@ -81,14 +114,26 @@ isort==5.10.1
# via
# -r requirements-dev.in
# pylint
kiwisolver==1.4.4
# via matplotlib
kornia==0.6
# via imaginAIry (setup.py)
lazy-object-proxy==1.7.1
# via astroid
llvmlite==0.39.1
# via numba
lmdb==1.3.0
# via
# basicsr
# gfpgan
markdown==3.4.1
# via tensorboard
# via
# tb-nightly
# tensorboard
markupsafe==2.1.1
# via werkzeug
matplotlib==3.5.3
# via filterpy
mccabe==0.7.0
# via
# pylama
@ -99,13 +144,30 @@ multidict==6.0.2
# yarl
mypy-extensions==0.4.3
# via black
networkx==2.8.6
# via scikit-image
numba==0.56.2
# via facexlib
numpy==1.23.3
# via
# basicsr
# diffusers
# facexlib
# filterpy
# gfpgan
# imageio
# imaginAIry (setup.py)
# matplotlib
# numba
# opencv-python
# pytorch-lightning
# pywavelets
# realesrgan
# scikit-image
# scipy
# tb-nightly
# tensorboard
# tifffile
# torchmetrics
# torchvision
# transformers
@ -113,20 +175,33 @@ oauthlib==3.2.1
# via requests-oauthlib
omegaconf==2.1.1
# via imaginAIry (setup.py)
opencv-python==4.6.0.66
# via
# basicsr
# facexlib
# gfpgan
# realesrgan
packaging==21.3
# via
# huggingface-hub
# kornia
# matplotlib
# pytest
# pytorch-lightning
# scikit-image
# torchmetrics
# transformers
pathspec==0.10.1
# via black
pillow==9.2.0
# via
# basicsr
# diffusers
# facexlib
# imageio
# matplotlib
# realesrgan
# scikit-image
# torchvision
platformdirs==2.5.2
# via
@ -135,7 +210,9 @@ platformdirs==2.5.2
pluggy==1.0.0
# via pytest
protobuf==3.19.4
# via tensorboard
# via
# tb-nightly
# tensorboard
py==1.11.0
# via pytest
pyasn1==0.4.8
@ -159,27 +236,39 @@ pylama==8.4.1
pylint==2.15.2
# via -r requirements-dev.in
pyparsing==3.0.9
# via packaging
# via
# matplotlib
# packaging
pytest==7.1.3
# via -r requirements-dev.in
python-dateutil==2.8.2
# via matplotlib
pytorch-lightning==1.4.2
# via imaginAIry (setup.py)
pywavelets==1.3.0
# via scikit-image
pyyaml==6.0
# via
# basicsr
# gfpgan
# huggingface-hub
# omegaconf
# pytorch-lightning
# transformers
realesrgan==0.2.5.0
# via imaginAIry (setup.py)
regex==2022.9.11
# via
# diffusers
# transformers
requests==2.28.1
# via
# basicsr
# diffusers
# fsspec
# huggingface-hub
# requests-oauthlib
# tb-nightly
# tensorboard
# torchvision
# transformers
@ -187,18 +276,38 @@ requests-oauthlib==1.3.1
# via google-auth-oauthlib
rsa==4.9
# via google-auth
scikit-image==0.19.3
# via basicsr
scipy==1.9.1
# via
# basicsr
# facexlib
# filterpy
# gfpgan
# scikit-image
six==1.16.0
# via
# google-auth
# grpcio
# python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.11.0a20220912
# via
# basicsr
# gfpgan
tensorboard==2.10.0
# via pytorch-lightning
tensorboard-data-server==0.6.1
# via tensorboard
# via
# tb-nightly
# tensorboard
tensorboard-plugin-wit==1.8.1
# via tensorboard
# via
# tb-nightly
# tensorboard
tifffile==2022.8.12
# via scikit-image
tokenizers==0.12.1
# via transformers
tomli==2.0.1
@ -210,10 +319,14 @@ tomlkit==0.11.4
# via pylint
torch==1.12.1
# via
# basicsr
# diffusers
# facexlib
# gfpgan
# imaginAIry (setup.py)
# kornia
# pytorch-lightning
# realesrgan
# torchmetrics
# torchvision
torchmetrics==0.6.0
@ -221,12 +334,21 @@ torchmetrics==0.6.0
# imaginAIry (setup.py)
# pytorch-lightning
torchvision==0.13.1
# via imaginAIry (setup.py)
# via
# basicsr
# facexlib
# gfpgan
# imaginAIry (setup.py)
# realesrgan
tqdm==4.64.1
# via
# basicsr
# facexlib
# gfpgan
# huggingface-hub
# imaginAIry (setup.py)
# pytorch-lightning
# realesrgan
# transformers
transformers==4.19.2
# via imaginAIry (setup.py)
@ -238,12 +360,22 @@ typing-extensions==4.3.0
# torchvision
urllib3==1.26.12
# via requests
wcwidth==0.2.5
# via ftfy
werkzeug==2.2.2
# via tensorboard
# via
# tb-nightly
# tensorboard
wheel==0.37.1
# via tensorboard
# via
# tb-nightly
# tensorboard
wrapt==1.14.1
# via astroid
yapf==0.32.0
# via
# basicsr
# gfpgan
yarl==1.8.1
# via aiohttp
zipp==3.8.1

View File

@ -23,7 +23,7 @@ setup(
install_requires=[
"click",
"ftfy", # for vendored clip
"torch",
"torch>=1.2.0",
"numpy",
"tqdm",
"diffusers",
@ -37,5 +37,7 @@ setup(
"kornia==0.6",
# k-diffusion for use with find_noise.py
# "k-diffusion@git+https://github.com/crowsonkb/k-diffusion.git@71ba7d6735e9cba1945b429a21345960eb3f151c#egg=k-diffusion",
"realesrgan",
"gfpgan>=1.3.7",
],
)

View File

@ -1,4 +1,4 @@
from imaginairy.api import imagine_image_files, imagine_images
from imaginairy.api import imagine, imagine_image_files
from imaginairy.schema import ImaginePrompt
from . import TESTS_FOLDER
@ -8,7 +8,7 @@ def test_imagine():
prompt = ImaginePrompt(
"a scenic landscape", width=512, height=256, steps=20, seed=1
)
result = next(imagine_images(prompt))
result = next(imagine(prompt))
assert result.md5() == "4c5957c498881d365cfcf13014812af0"
result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")