mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-10-31 03:20:40 +00:00)

feature: face enhancement and upscaling!!

commit 541ecb9701 (parent 6fa776053f)
3  .gitignore (vendored)

@@ -14,4 +14,5 @@ build
 dist
 **/*.ckpt
 **/*.egg-info
 tests/test_output
+gfpgan/**
26  Makefile

@@ -16,6 +16,7 @@ init: require_pyenv  ## Setup a dev environment for local development.
 	@echo -e "\033[0;32m ✔️  🐍 $(venv_name) virtualenv activated \033[0m"
 	pip install --upgrade pip pip-tools
 	pip-sync requirements-dev.txt
+	pip install -e . --no-deps
 	@echo -e "\nEnvironment setup! ✨ 🍰 ✨ 🐍 \n\nCopy this path to tell PyCharm where your virtualenv is. You may have to click the refresh button in the pycharm file explorer.\n"
 	@echo -e "\033[0;32m"
 	@pyenv which python

@@ -75,6 +76,31 @@ vendor_openai_clip:
 	git --git-dir ./downloads/CLIP/.git rev-parse HEAD | tee ./imaginairy/vendored/clip/clip-commit-hash.txt
 	echo "vendored from git@github.com:openai/CLIP.git" | tee ./imaginairy/vendored/clip/readme.txt
 
+revendorize:
+	make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip
+	make vendorize REPO=git@github.com:xinntao/Real-ESRGAN.git PKG=realesrgan
+
+vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
+	mkdir -p ./downloads
+	-cd ./downloads && git clone $(REPO) $(PKG)
+	cd ./downloads/$(PKG) && git pull
+	rm -rf ./imaginairy/vendored/$(PKG)
+	cp -R ./downloads/$(PKG)/$(PKG) imaginairy/vendored/
+	git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt
+	touch ./imaginairy/vendored/$(PKG)/version.py
+	echo "vendored from $(REPO)" | tee ./imaginairy/vendored/$(PKG)/readme.txt
+
+vendorize_whole_repo:
+	mkdir -p ./downloads
+	-cd ./downloads && git clone $(REPO) $(PKG)
+	cd ./downloads/$(PKG) && git pull
+	rm -rf ./imaginairy/vendored/$(PKG)
+	cp -R ./downloads/$(PKG) imaginairy/vendored/
+	git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt
+	touch ./imaginairy/vendored/$(PKG)/version.py
+	echo "vendored from $(REPO)" | tee ./imaginairy/vendored/$(PKG)/readme.txt
+
+
 help: ## Show this help message.
 	@## https://gist.github.com/prwhite/8168133#gistcomment-1716694
46  README.md

@@ -53,6 +53,17 @@ Generating 🖼  : "portrait photo of a freckled woman" 512x512px seed:500686645
 <img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> =>
 <img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
 
+### Face Enhancement [by CodeFormer](https://github.com/sczhou/CodeFormer)
+
+```bash
+>> imagine "a couple smiling" --steps 40 --seed 1 --fix-faces
+```
+<img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> => <img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
+
+### Upscaling [by RealESRGAN](https://github.com/xinntao/Real-ESRGAN)
+<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg">
+<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="512">
+
 ## Features

@@ -68,20 +79,20 @@ Generating 🖼  : "portrait photo of a freckled woman" 512x512px seed:500686645
 ## How To
 
 ```python
-from imaginairy import imagine_images, imagine_image_files, ImaginePrompt, WeightedPrompt
+from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt
 
 prompts = [
     ImaginePrompt("a scenic landscape", seed=1),
     ImaginePrompt("a bowl of fruit"),
     ImaginePrompt([
         WeightedPrompt("cat", weight=1),
         WeightedPrompt("dog", weight=1),
     ])
 ]
-for result in imagine_images(prompts):
+for result in imagine(prompts):
     # do something
     result.save("my_image.jpg")
 
 # or
 
 imagine_image_files(prompts, outdir="./my-art")

@@ -109,11 +120,12 @@ imagine_image_files(prompts, outdir="./my-art")
 
 ## Todo
 - performance optimizations
-  - https://github.com/neonsecret/stable-diffusion https://github.com/CompVis/stable-diffusion/pull/177
   - ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
   - ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#
   - ✅ https://www.reddit.com/r/StableDiffusion/comments/xalaws/test_update_for_less_memory_usage_and_higher/
-- deploy to pypi
+  - https://github.com/neonsecret/stable-diffusion https://github.com/CompVis/stable-diffusion/pull/177
+- ✅ deploy to pypi
 - add tests
 - set up ci (test/lint/format)
 - add docs

@@ -124,11 +136,14 @@ imagine_image_files(prompts, outdir="./my-art")
 - ✅ init-image at command line
 - prompt expansion
 - Image Generation Features
 - add in all the samplers
 - upscaling
+  - ✅ realesrgan
+  - ldm
+  - https://github.com/lowfuel/progrock-stable
-- face improvements
-  - gfpgan - https://github.com/TencentARC/GFPGAN
-  - codeformer - https://github.com/sczhou/CodeFormer
+- ✅ face enhancers
+  - ✅ gfpgan - https://github.com/TencentARC/GFPGAN
+  - ✅ codeformer - https://github.com/sczhou/CodeFormer
 - image describe feature - https://replicate.com/methexis-inc/img2prompt
 - outpainting
 - inpainting

@@ -148,5 +163,16 @@ imagine_image_files(prompts, outdir="./my-art")
 - textual inversion
   - https://www.reddit.com/r/StableDiffusion/comments/xbwb5y/how_to_run_textual_inversion_locally_train_your/
+  - https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb#scrollTo=50JuJUM8EG1h
+  - https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_textual_inversion_library_navigator.ipynb
 - zooming videos? a la disco diffusion
 - fix saturation at high CFG https://www.reddit.com/r/StableDiffusion/comments/xalo78/fixing_excessive_contrastsaturation_resulting/
 - https://www.reddit.com/r/StableDiffusion/comments/xbrrgt/a_rundown_of_twenty_new_methodsoptions_added_to/
+
+## Notable Stable Diffusion Implementations
+- https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion
+- https://github.com/lstein/stable-diffusion
+- https://github.com/AUTOMATIC1111/stable-diffusion-webui
+
+## Further Reading
+- Differences between samplers
+  - https://www.reddit.com/r/StableDiffusion/comments/xbeyw3/can_anyone_offer_a_little_guidance_on_the/
BIN  assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png (new file, 378 KiB; binary file not shown)
BIN  assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png (new file, 286 KiB; binary file not shown)
BIN  assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg (new file, 33 KiB; binary file not shown)
BIN  assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg (new file, 220 KiB; binary file not shown)
imaginairy/__init__.py

@@ -2,5 +2,5 @@ import os
 
 os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
 
-from .api import imagine_image_files, imagine_images  # noqa
+from .api import imagine, imagine_image_files  # noqa
 from .schema import ImaginePrompt, ImagineResult, WeightedPrompt  # noqa
imaginairy/api.py

@@ -1,7 +1,6 @@
 import logging
 import os
 import re
-import subprocess
 from contextlib import nullcontext
 from functools import lru_cache
 
@@ -10,11 +9,13 @@ import torch
 import torch.nn
 from einops import rearrange
 from omegaconf import OmegaConf
-from PIL import Image, ImageDraw
+from PIL import Image, ImageDraw, ImageFilter
 from pytorch_lightning import seed_everything
 from torch import autocast
 from transformers import cached_path
 
+from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
+from imaginairy.enhancers.upscale_realesrgan import upscale_image
 from imaginairy.modules.diffusion.ddim import DDIMSampler
 from imaginairy.modules.diffusion.plms import PLMSSampler
 from imaginairy.safety import is_nsfw, safety_models

@@ -93,7 +94,7 @@ def imagine_image_files(
 ):
     big_path = os.path.join(outdir, "upscaled")
     os.makedirs(outdir, exist_ok=True)
-    os.makedirs(big_path, exist_ok=True)
 
     base_count = len(os.listdir(outdir))
     step_count = 0
     output_file_extension = output_file_extension.lower()

@@ -116,7 +117,7 @@ def imagine_image_files(
             img.save(os.path.join(steps_path, filename))
 
     img_callback = _record_steps if record_step_images else None
-    for result in imagine_images(
+    for result in imagine(
         prompts,
         latent_channels=latent_channels,
         downsampling_factor=downsampling_factor,

@@ -131,14 +132,15 @@ def imagine_image_files(
 
         result.save(filepath)
         logger.info(f"    🖼  saved to: {filepath}")
-        if prompt.upscale:
-            bigfilepath = (os.path.join(big_path, basefilename) + ".jpg",)
-            enlarge_realesrgan2x(filepath, bigfilepath)
-            logger.info(f"    upscaled 🖼  saved to: {filepath}")
+        if result.upscaled_img:
+            os.makedirs(big_path, exist_ok=True)
+            bigfilepath = os.path.join(big_path, basefilename) + "_upscaled.jpg"
+            result.save_upscaled(bigfilepath)
+            logger.info(f"    Upscaled 🖼  saved to: {bigfilepath}")
         base_count += 1
 
 
-def imagine_images(
+def imagine(
     prompts,
     latent_channels=4,
     downsampling_factor=8,

@@ -149,15 +151,16 @@ def imagine(
     half_mode=None,
 ):
     model = load_model(tile_mode=tile_mode)
+
     if not IMAGINAIRY_ALLOW_NSFW:
         # needs to be loaded before we set default tensor type to half
         safety_models()
     # only run half-mode on cuda. run it by default
-    half_mode = True if half_mode is None and get_device() == "cuda" else False
+    half_mode = half_mode is None and get_device() == "cuda"
     if half_mode:
         model = model.half()
         # needed when model is in half mode, remove if not using half mode
-        torch.set_default_tensor_type(torch.HalfTensor)
+        # torch.set_default_tensor_type(torch.HalfTensor)
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None

@@ -242,20 +245,23 @@ def imagine(
 
                 for x_sample in x_samples:
                     x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
-                    img = Image.fromarray(x_sample.astype(np.uint8))
+                    x_sample_8_orig = x_sample.astype(np.uint8)
+                    img = Image.fromarray(x_sample_8_orig)
+                    upscaled_img = None
                     if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
                         img, x_sample, half_mode=half_mode
                     ):
                         logger.info("    ⚠️  Filtering NSFW image")
-                        img = Image.new("RGB", img.size, (228, 150, 150))
+                        img = img.filter(ImageFilter.GaussianBlur(radius=10))
 
                     if prompt.fix_faces:
-                        img = fix_faces_GFPGAN(img)
-                    # if prompt.upscale:
-                    #     enlarge_realesrgan2x(
-                    #         filepath,
-                    #         os.path.join(big_path, basefilename) + ".jpg",
-                    #     )
-                    yield ImagineResult(img=img, prompt=prompt)
+                        logger.info("    Fixing 😊 's in 🖼  using GFPGAN...")
+                        img = enhance_faces(img, fidelity=0.2)
+                    if prompt.upscale:
+                        logger.info("    Upscaling 🖼  using real-ESRGAN...")
+                        upscaled_img = upscale_image(img)
+
+                    yield ImagineResult(img=img, prompt=prompt, upscaled_img=upscaled_img)

@@ -263,14 +269,6 @@ def prompt_normalized(prompt):
 
 
-DOWNLOADED_FILES_PATH = f"{LIB_PATH}/../downloads/"
-ESRGAN_PATH = DOWNLOADED_FILES_PATH + "realesrgan-ncnn-vulkan/realesrgan-ncnn-vulkan"
-
-
-def enlarge_realesrgan2x(src, dst):
-    process = subprocess.Popen(
-        [ESRGAN_PATH, "-i", src, "-o", dst, "-n", "realesrgan-x4plus"]
-    )
-    process.wait()
 
 
 def get_sampler(sampler_type, model):

@@ -279,30 +277,3 @@ def get_sampler(sampler_type, model):
         return PLMSSampler(model)
     elif sampler_type == "DDIM":
         return DDIMSampler(model)
-
-
-def gfpgan_model():
-    from gfpgan import GFPGANer
-
-    return GFPGANer(
-        model_path=DOWNLOADED_FILES_PATH
-        + "GFPGAN/experiments/pretrained_models/GFPGANv1.3.pth",
-        upscale=1,
-        arch="clean",
-        channel_multiplier=2,
-        bg_upsampler=None,
-        device=torch.device(get_device()),
-    )
-
-
-def fix_faces_GFPGAN(image):
-    image = image.convert("RGB")
-    cropped_faces, restored_faces, restored_img = gfpgan_model().enhance(
-        np.array(image, dtype=np.uint8),
-        has_aligned=False,
-        only_center_face=False,
-        paste_back=True,
-    )
-    res = Image.fromarray(restored_img)
-
-    return res
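Taken together, the API now carries upscaled output on the result object instead of shelling out to the realesrgan-ncnn-vulkan binary. A minimal sketch of driving the renamed generator (prompt text and output file names here are illustrative, and model weights must be downloadable on first use):

```python
# Sketch of the reworked API: imagine() is a generator of ImagineResult
# objects; upscaled_img is only populated when the prompt asked for upscaling.
from imaginairy import ImaginePrompt, imagine

prompts = [ImaginePrompt("a couple smiling", seed=1, fix_faces=True, upscale=True)]
for result in imagine(prompts):
    result.save("smiling.jpg")  # base (face-fixed) image
    if result.upscaled_img:
        result.save_upscaled("smiling_upscaled.jpg")  # Real-ESRGAN 4x output
```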
@@ -46,6 +46,12 @@ def filter_torch_warnings():
         category=UserWarning,
         message=r"The operator .*?is not currently supported.*",
     )
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message=r"The parameter 'pretrained' is.*"
+    )
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message=r"Arguments other than a weight.*"
+    )
 
 
 def setup_env():
@@ -92,6 +92,12 @@ def configure_logging(level="INFO"):
     type=int,
     help="What seed to use for randomness. Allows reproducible image renders",
 )
+@click.option("--upscale", is_flag=True)
+@click.option(
+    "--upscale-method", default="realesrgan", type=click.Choice(["realesrgan"])
+)
+@click.option("--fix-faces", is_flag=True)
+@click.option("--fix-faces-method", default="gfpgan", type=click.Choice(["gfpgan"]))
 @click.option(
     "--sampler-type",
     default="PLMS",

@@ -128,6 +134,10 @@ def imagine_cmd(
     width,
     steps,
     seed,
+    upscale,
+    upscale_method,
+    fix_faces,
+    fix_faces_method,
     sampler_type,
     ddim_eta,
     log_level,

@@ -142,7 +152,7 @@ def imagine_cmd(
 
     total_image_count = len(prompt_texts) * repeats
     logger.info(
-        f"🤖🧠 received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
+        f"🤖🧠 imaginAIry received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
     )
     if init_image and sampler_type != "DDIM":
         sampler_type = "DDIM"

@@ -161,8 +171,8 @@ def imagine_cmd(
             steps=steps,
             height=height,
             width=width,
-            upscale=False,
-            fix_faces=False,
+            upscale=upscale,
+            fix_faces=fix_faces,
         )
         prompts.append(prompt)
 
@@ -172,6 +182,7 @@ def imagine_cmd(
         ddim_eta=ddim_eta,
         record_step_images="images" in show_work,
         tile_mode=tile,
+        output_file_extension="png",
     )
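The new options thread straight through to `ImaginePrompt`. A small sketch of exercising the installed console script with them (this assumes `pip install imaginairy` has put the `imagine` entry point on the PATH):

```python
# Sketch: invoking the `imagine` console script with the new flags.
import subprocess

subprocess.run(
    [
        "imagine", "a couple smiling",
        "--steps", "40", "--seed", "1",
        "--fix-faces",  # CodeFormer-based face enhancement
        "--upscale",    # Real-ESRGAN 4x upscale
    ],
    check=True,
)
```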
0  imaginairy/enhancers/__init__.py (new file)

86  imaginairy/enhancers/face_restoration_codeformer.py (new file)
@@ -0,0 +1,86 @@
import logging
from functools import lru_cache

import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize

from imaginairy.utils import get_cached_url_path
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer

logger = logging.getLogger(__name__)


@lru_cache()
def codeformer_model():
    model = CodeFormer(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=["32", "64", "128", "256"],
    ).to("cpu")
    url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
    ckpt_path = get_cached_url_path(url)
    checkpoint = torch.load(ckpt_path)["params_ema"]
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def enhance_faces(img, fidelity=0):
    net = codeformer_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    face_helper = FaceRestoreHelper(
        1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        save_ext="png",
        use_parse=True,
        device=device,
    )
    face_helper.clean_all()

    image = img.convert("RGB")
    np_img = np.array(image, dtype=np.uint8)
    # rotate to BGR
    np_img = np_img[:, :, ::-1]

    face_helper.read_image(np_img)
    # get face landmarks for each face
    num_det_faces = face_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5
    )
    logger.info(f"    Enhancing {num_det_faces} faces")
    # align and warp each face
    face_helper.align_warp_face()

    # face restoration for each cropped face
    for idx, cropped_face in enumerate(face_helper.cropped_faces):
        # prepare data
        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to("cpu")

        try:
            with torch.no_grad():
                output = net(cropped_face_t, w=fidelity, adain=True)[0]
                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            del output
            torch.cuda.empty_cache()
        except Exception as error:
            logger.error(f"\tFailed inference for CodeFormer: {error}")
            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

        restored_face = restored_face.astype("uint8")
        face_helper.add_restored_face(restored_face)

    face_helper.get_inverse_affine(None)
    # paste each restored face to the input image
    restored_img = face_helper.paste_faces_to_input_image()
    res = Image.fromarray(restored_img[:, :, ::-1])
    return res
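The enhancer is self-contained, so it can be exercised outside of `imagine()` as well. A minimal sketch (the input path is a stand-in; `fidelity=0` restores most aggressively, and `api.imagine` passes `fidelity=0.2`):

```python
# Sketch: run the CodeFormer enhancer directly on a PIL image.
# "portrait.jpg" is a placeholder path; weights download on first use.
from PIL import Image

from imaginairy.enhancers.face_restoration_codeformer import enhance_faces

img = Image.open("portrait.jpg")
fixed = enhance_faces(img, fidelity=0.2)  # 0 = max restoration, 1 = max fidelity
fixed.save("portrait_fixed.jpg")
```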
50  imaginairy/enhancers/face_restoration_gfpgan.py (new file)
@@ -0,0 +1,50 @@
from functools import lru_cache

import numpy as np
import torch
from PIL import Image

from imaginairy.utils import get_cached_url_path, get_device


@lru_cache()
def face_enhance_model(model_type="codeformer"):
    from gfpgan import GFPGANer

    if model_type == "gfpgan":
        arch = "clean"
        url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
    elif model_type == "codeformer":
        arch = "CodeFormer"
        url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth"
    else:
        raise ValueError("model_type must be one of gfpgan, codeformer")

    model_path = get_cached_url_path(url)

    if get_device() == "cuda":
        device = "cuda"
    else:
        device = "cpu"

    return GFPGANer(
        model_path=model_path,
        upscale=1,
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=None,
        device=device,
    )


def fix_faces_gfpgan(image, model_type):
    image = image.convert("RGB")
    np_img = np.array(image, dtype=np.uint8)
    cropped_faces, restored_faces, restored_img = face_enhance_model(
        model_type
    ).enhance(
        np_img, has_aligned=False, only_center_face=False, paste_back=True, weight=0
    )
    res = Image.fromarray(restored_img)

    return res
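The GFPGAN-backed variant works the same way; `model_type` picks which pretrained checkpoint `GFPGANer` loads. A sketch with an illustrative input path:

```python
# Sketch: GFPGAN-based face fixing; model_type is "gfpgan" or "codeformer".
from PIL import Image

from imaginairy.enhancers.face_restoration_gfpgan import fix_faces_gfpgan

img = Image.open("portrait.jpg")  # placeholder input
fixed = fix_faces_gfpgan(img, model_type="gfpgan")
fixed.save("portrait_gfpgan.jpg")
```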
40  imaginairy/enhancers/upscale_realesrgan.py (new file)
@@ -0,0 +1,40 @@
from functools import lru_cache

import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer

from imaginairy.utils import get_cached_url_path, get_device


@lru_cache()
def realesrgan_upsampler():
    model = RRDBNet(
        num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
    )
    url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    model_path = get_cached_url_path(url)
    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)

    if get_device() == "cuda":
        device = "cuda"
    else:
        device = "cpu"

    upsampler.device = torch.device(device)
    upsampler.model.to(device)

    return upsampler


def upscale_image(img):
    img = img.convert("RGB")
    np_img = np.array(img, dtype=np.uint8)
    upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
    return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)


if __name__ == "__main__":
    realesrgan_upsampler()
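Usage is a single call; `upscale_image` handles the RGB-to-BGR round trip that the Real-ESRGAN upsampler expects internally. A sketch with an illustrative file name:

```python
# Sketch: 4x upscale via the cached Real-ESRGAN upsampler.
from PIL import Image

from imaginairy.enhancers.upscale_realesrgan import upscale_image

small = Image.open("colorful_smoke.jpg")  # placeholder input
big = upscale_image(small)  # PIL image, 4x the input dimensions
big.save("colorful_smoke_upscaled.jpg")
```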
@@ -51,7 +51,7 @@ class VectorQuantizer(nn.Module):
         if self.unknown_index == "extra":
             self.unknown_index = self.re_embed
             self.re_embed = self.re_embed + 1
-        print(
+        logger.info(
             f"Remapping {self.n_e} indices to {self.re_embed} indices. "
             f"Using {self.unknown_index} for unknown indices."
         )
imaginairy/modules/diffusion/ddim.py

@@ -174,7 +174,8 @@ class DDIMSampler:
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
-            img = torch.randn(shape, device=device)
+            # run on CPU for seed consistency. M1/mps runs were not consistent otherwise
+            img = torch.randn(shape, device="cpu").to(device)
         else:
             img = x_T
 

imaginairy/modules/diffusion/plms.py

@@ -175,7 +175,8 @@ class PLMSSampler(object):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
-            img = torch.randn(shape, device=device)
+            # run on CPU for seed consistency. M1/mps runs were not consistent otherwise
+            img = torch.randn(shape, device="cpu").to(device)
         else:
             img = x_T
 
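Why this helps: PyTorch's CPU generator produces the same sequence for a given seed everywhere, while mps/cuda draws were not reproducible across runs on M1. A small sketch of the trick in isolation (the shape and device name are illustrative):

```python
# Sketch: seed-stable initial noise. Drawing on CPU and then moving to the
# compute device yields identical latents for identical seeds on any backend.
import torch

device = "cpu"  # stand-in for "mps" or "cuda"

torch.manual_seed(42)
noise_a = torch.randn((1, 4, 64, 64), device="cpu").to(device)

torch.manual_seed(42)
noise_b = torch.randn((1, 4, 64, 64), device="cpu").to(device)

assert torch.equal(noise_a, noise_b)  # same seed -> same starting latent
```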
@@ -26,9 +26,7 @@ def pil_img_to_latent(model, img, batch_size=1, device="cuda", half=True):
     return model.get_first_stage_encoding(model.encode_first_stage(init_image))
 
 
-def find_noise_for_image(
-    model, pil_img, prompt, steps=50, cond_scale=1.0, verbose=False, half=True
-):
+def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
     img_latent = pil_img_to_latent(
         model, pil_img, batch_size=1, device="cuda", half=half
     )

@@ -38,13 +36,12 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
         prompt,
         steps=steps,
         cond_scale=cond_scale,
-        verbose=verbose,
         half=half,
     )
 
 
 def find_noise_for_latent(
-    model, img_latent, prompt, steps=50, cond_scale=1.0, verbose=False, half=True
+    model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
 ):
     import k_diffusion as K
 

@@ -59,9 +56,6 @@ def find_noise_for_latent(
     dnw = K.external.CompVisDenoiser(model)
     sigmas = dnw.get_sigmas(steps).flip(0)
 
-    if verbose:
-        print(sigmas)
-
     with (torch.no_grad(), _autocast(get_device())):
         for i in range(1, len(sigmas)):
             x_in = torch.cat([x] * 2)
imaginairy/schema.py

@@ -91,8 +91,9 @@ class ExifCodes:
 
 
 class ImagineResult:
-    def __init__(self, img, prompt: ImaginePrompt):
+    def __init__(self, img, prompt: ImaginePrompt, upscaled_img=None):
         self.img = img
+        self.upscaled_img = upscaled_img
         self.prompt = prompt
         self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
         self.torch_backend = get_device()

@@ -113,7 +114,7 @@ class ImagineResult:
             "prompt": self.prompt.as_dict(),
         }
 
-    def save(self, save_path):
+    def _exif(self):
         exif = Exif()
         exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
         exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())

@@ -121,4 +122,10 @@ class ImagineResult:
         exif[ExifCodes.Software] = "Imaginairy / Stable Diffusion v1.4"
         exif[ExifCodes.DateTime] = self.created_at.isoformat(sep=" ")[:19]
         exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
-        self.img.save(save_path, exif=exif)
+        return exif
+
+    def save(self, save_path):
+        self.img.save(save_path, exif=self._exif())
+
+    def save_upscaled(self, save_path):
+        self.upscaled_img.save(save_path, exif=self._exif())
imaginairy/utils.py

@@ -1,14 +1,17 @@
 import importlib
 import logging
 import os.path
 import platform
 from contextlib import contextmanager
 from functools import lru_cache
 from typing import List, Optional
 
 import numpy as np
+import requests
 import torch
 from PIL import Image
 from torch import Tensor
+from transformers import cached_path
 
 logger = logging.getLogger(__name__)
 
@@ -115,3 +118,34 @@ def pillow_img_to_torch_image(image, max_height=512, max_width=512):
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0, w, h
+
+
+def get_cache_dir():
+    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
+    if xdg_cache_home is None:
+        user_home = os.getenv("HOME", None)
+        if user_home:
+            xdg_cache_home = os.path.join(user_home, ".cache")
+
+    if xdg_cache_home is not None:
+        return os.path.join(xdg_cache_home, "imaginairy", "weights")
+
+    return os.path.join(os.path.dirname(__file__), ".cached-downloads")
+
+
+def get_cached_url_path(url):
+    try:
+        return cached_path(url)
+    except OSError:
+        pass
+    filename = url.split("/")[-1]
+    dest = get_cache_dir()
+    os.makedirs(dest, exist_ok=True)
+    dest_path = os.path.join(dest, filename)
+    if os.path.exists(dest_path):
+        return dest_path
+    r = requests.get(url)
+
+    with open(dest_path, "wb") as f:
+        f.write(r.content)
+    return dest_path
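A quick sketch of the new weights-caching helper: the first call downloads, later calls hit the XDG cache (the URL shown is the CodeFormer checkpoint this commit actually uses):

```python
# Sketch: resolve a weights URL to a local, cached file path.
from imaginairy.utils import get_cached_url_path

url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
ckpt_path = get_cached_url_path(url)  # downloads once, then reuses
print(ckpt_path)  # e.g. ~/.cache/imaginairy/weights/codeformer.pth
```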
0  imaginairy/vendored/clip/version.py (new file)
0  imaginairy/vendored/codeformer/__init__.py (new file)

325  imaginairy/vendored/codeformer/codeformer_arch.py (new file)
@@ -0,0 +1,325 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F

# from basicsr.archs.vqgan_arch import *
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn

from imaginairy.vendored.codeformer.vqgan_arch import ResBlock, VQAutoEncoder


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, "The input feature should be 4D tensor."
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    as those in the degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
        size
    )
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
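For reference, the two helpers above implement standard AdaIN; with the channel-wise statistics μ(·) and σ(·) computed by `calc_mean_std`, the returned tensor is:

```latex
\mathrm{AdaIN}(x, y) = \sigma(y)\,\frac{x - \mu(x)}{\sigma(x)} + \mu(y)
```

where x is the content (reference) feature and y is the degraded feature whose statistics are transferred.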

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
            )
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(
        self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt


class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out


@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(
        self,
        dim_embd=512,
        n_head=8,
        n_layers=9,
        codebook_size=1024,
        latent_size=256,
        connect_list=["32", "64", "128", "256"],
        fix_modules=["quantize", "generator"],
    ):
        super(CodeFormer, self).__init__(
            512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
        )

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(
            *[
                TransformerSALayer(
                    embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
                )
                for _ in range(self.n_layers)
            ]
        )

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
        )

        self.channels = {
            "16": 512,
            "32": 256,
            "64": 256,
            "128": 128,
            "256": 128,
            "512": 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {
            "512": 2,
            "256": 5,
            "128": 8,
            "64": 11,
            "32": 14,
            "16": 18,
        }
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {
            "16": 6,
            "32": 9,
            "64": 12,
            "128": 15,
            "256": 18,
            "512": 21,
        }

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(
            top_idx, shape=[x.shape[0], 16, 16, 256]
        )
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](
                        enc_feat_dict[f_size].detach(), x, w
                    )
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
1  imaginairy/vendored/codeformer/readme.txt (new file)
@@ -0,0 +1 @@
vendored from git@github.com:sczhou/CodeFormer.git c5b4593074ba6214284d6acd5f1719b6c5d739af

514  imaginairy/vendored/codeformer/vqgan_arch.py (new file)
@@ -0,0 +1,514 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

"""
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY


def normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


@torch.jit.script
def swish(x):
    return x * torch.sigmoid(x)


# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(
            -1.0 / self.codebook_size, 1.0 / self.codebook_size
        )

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            (z_flattened**2).sum(dim=1, keepdim=True)
            + (self.embedding.weight**2).sum(1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        mean_distance = torch.mean(d)
        # find closest encodings
        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encoding_scores, min_encoding_indices = torch.topk(
            d, 1, dim=1, largest=False
        )
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-min_encoding_scores / 10)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.codebook_size
        ).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
            (z_q - z.detach()) ** 2
        )
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return (
            z_q,
            loss,
            {
                "perplexity": perplexity,
                "min_encodings": min_encodings,
                "min_encoding_indices": min_encoding_indices,
                "min_encoding_scores": min_encoding_scores,
                "mean_distance": mean_distance,
            },
        )
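The `loss` line above is the usual VQ-VAE objective; writing sg[·] for stop-gradient (`.detach()`), it computes:

```latex
\mathcal{L}_{\mathrm{vq}} = \left\lVert \mathrm{sg}[z_q] - z \right\rVert_2^2 + \beta \left\lVert z_q - \mathrm{sg}[z] \right\rVert_2^2
```

and the straight-through line z_q = z + sg[z_q - z] lets encoder gradients flow through the quantization step.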
    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantizer(nn.Module):
    def __init__(
        self,
        codebook_size,
        emb_dim,
        num_hiddens,
        straight_through=False,
        kl_weight=5e-4,
        temp_init=1.0,
    ):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(
            num_hiddens, codebook_size, 1
        )  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = (
            self.kl_weight
            * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        )
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {"min_encoding_indices": min_encoding_indices}


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
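`AttnBlock` is single-head dot-product attention over the h·w spatial positions, with the channel count c as the scaling dimension:

```latex
\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{c}}\right) V
```

followed by a 1x1 output projection and a residual connection back onto x.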

class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        nf,
        emb_dim,
        ch_mult,
        num_res_blocks,
        resolution,
        attn_resolutions,
    ):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(
            nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
        )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)

        blocks = []
        # initial conv
        blocks.append(
            nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
        )

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(
            nn.Conv2d(
                block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(
        self,
        img_size,
        nf,
        ch_mult,
        quantizer="nearest",
        res_blocks=2,
        attn_resolutions=[16],
        codebook_size=1024,
        emb_dim=256,
        beta=0.25,
        gumbel_straight_through=False,
        gumbel_kl_weight=1e-8,
        model_path=None,
    ):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions,
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(
                self.codebook_size, self.embed_dim, self.beta
            )
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight,
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions,
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location="cpu")
            if "params_ema" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params_ema"]
                )
                logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
            elif "params" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params"]
                )
                logger.info(f"vqgan is loaded from: {model_path} [params]")
            else:
                raise ValueError("Wrong params!")

    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats


# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2**n, 8)
            layers += [
                nn.Conv2d(
                    ndf * ndf_mult_prev,
                    ndf * ndf_mult,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2**n_layers, 8)

        layers += [
            nn.Conv2d(
                ndf * ndf_mult_prev,
                ndf * ndf_mult,
                kernel_size=4,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location="cpu")
            if "params_d" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params_d"]
                )
            elif "params" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params"]
                )
            else:
                raise ValueError("Wrong params!")

    def forward(self, x):
        return self.main(x)
requirements-dev.txt
@@ -5,7 +5,11 @@
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
absl-py==1.2.0
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
addict==2.4.0
    # via basicsr
aiohttp==3.8.1
    # via fsspec
aiosignal==1.2.0
@@ -20,6 +24,10 @@ attrs==22.1.0
    # via
    #   aiohttp
    #   pytest
basicsr==1.4.2
    # via
    #   gfpgan
    #   realesrgan
black==22.8.0
    # via -r requirements-dev.in
cachetools==5.2.0
@@ -36,33 +44,56 @@ click==8.1.3
    #   imaginAIry (setup.py)
coverage==6.4.4
    # via -r requirements-dev.in
cycler==0.11.0
    # via matplotlib
diffusers==0.3.0
    # via imaginAIry (setup.py)
dill==0.3.5.1
    # via pylint
einops==0.3.0
    # via imaginAIry (setup.py)
facexlib==0.2.5
    # via
    #   gfpgan
    #   realesrgan
filelock==3.8.0
    # via
    #   diffusers
    #   huggingface-hub
    #   transformers
filterpy==1.4.5
    # via facexlib
fonttools==4.37.1
    # via matplotlib
frozenlist==1.3.1
    # via
    #   aiohttp
    #   aiosignal
fsspec[http]==2022.8.2
    # via pytorch-lightning
ftfy==6.1.1
    # via imaginAIry (setup.py)
future==0.18.2
    # via pytorch-lightning
    # via
    #   basicsr
    #   pytorch-lightning
gfpgan==1.3.7
    # via
    #   imaginAIry (setup.py)
    #   realesrgan
google-auth==2.11.0
    # via
    #   google-auth-oauthlib
    #   tb-nightly
    #   tensorboard
google-auth-oauthlib==0.4.6
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
grpcio==1.48.1
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
huggingface-hub==0.9.1
    # via
    #   diffusers
@@ -72,7 +103,9 @@ idna==3.3
    #   requests
    #   yarl
imageio==2.9.0
    # via imaginAIry (setup.py)
    # via
    #   imaginAIry (setup.py)
    #   scikit-image
importlib-metadata==4.12.0
    # via diffusers
iniconfig==1.1.1
@@ -81,14 +114,26 @@ isort==5.10.1
    # via
    #   -r requirements-dev.in
    #   pylint
kiwisolver==1.4.4
    # via matplotlib
kornia==0.6
    # via imaginAIry (setup.py)
lazy-object-proxy==1.7.1
    # via astroid
llvmlite==0.39.1
    # via numba
lmdb==1.3.0
    # via
    #   basicsr
    #   gfpgan
markdown==3.4.1
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
markupsafe==2.1.1
    # via werkzeug
matplotlib==3.5.3
    # via filterpy
mccabe==0.7.0
    # via
    #   pylama
@@ -99,13 +144,30 @@ multidict==6.0.2
    #   yarl
mypy-extensions==0.4.3
    # via black
networkx==2.8.6
    # via scikit-image
numba==0.56.2
    # via facexlib
numpy==1.23.3
    # via
    #   basicsr
    #   diffusers
    #   facexlib
    #   filterpy
    #   gfpgan
    #   imageio
    #   imaginAIry (setup.py)
    #   matplotlib
    #   numba
    #   opencv-python
    #   pytorch-lightning
    #   pywavelets
    #   realesrgan
    #   scikit-image
    #   scipy
    #   tb-nightly
    #   tensorboard
    #   tifffile
    #   torchmetrics
    #   torchvision
    #   transformers
@@ -113,20 +175,33 @@ oauthlib==3.2.1
    # via requests-oauthlib
omegaconf==2.1.1
    # via imaginAIry (setup.py)
opencv-python==4.6.0.66
    # via
    #   basicsr
    #   facexlib
    #   gfpgan
    #   realesrgan
packaging==21.3
    # via
    #   huggingface-hub
    #   kornia
    #   matplotlib
    #   pytest
    #   pytorch-lightning
    #   scikit-image
    #   torchmetrics
    #   transformers
pathspec==0.10.1
    # via black
pillow==9.2.0
    # via
    #   basicsr
    #   diffusers
    #   facexlib
    #   imageio
    #   matplotlib
    #   realesrgan
    #   scikit-image
    #   torchvision
platformdirs==2.5.2
    # via
@@ -135,7 +210,9 @@ platformdirs==2.5.2
pluggy==1.0.0
    # via pytest
protobuf==3.19.4
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
py==1.11.0
    # via pytest
pyasn1==0.4.8
@@ -159,27 +236,39 @@ pylama==8.4.1
pylint==2.15.2
    # via -r requirements-dev.in
pyparsing==3.0.9
    # via packaging
    # via
    #   matplotlib
    #   packaging
pytest==7.1.3
    # via -r requirements-dev.in
python-dateutil==2.8.2
    # via matplotlib
pytorch-lightning==1.4.2
    # via imaginAIry (setup.py)
pywavelets==1.3.0
    # via scikit-image
pyyaml==6.0
    # via
    #   basicsr
    #   gfpgan
    #   huggingface-hub
    #   omegaconf
    #   pytorch-lightning
    #   transformers
realesrgan==0.2.5.0
    # via imaginAIry (setup.py)
regex==2022.9.11
    # via
    #   diffusers
    #   transformers
requests==2.28.1
    # via
    #   basicsr
    #   diffusers
    #   fsspec
    #   huggingface-hub
    #   requests-oauthlib
    #   tb-nightly
    #   tensorboard
    #   torchvision
    #   transformers
@@ -187,18 +276,38 @@ requests-oauthlib==1.3.1
    # via google-auth-oauthlib
rsa==4.9
    # via google-auth
scikit-image==0.19.3
    # via basicsr
scipy==1.9.1
    # via
    #   basicsr
    #   facexlib
    #   filterpy
    #   gfpgan
    #   scikit-image
six==1.16.0
    # via
    #   google-auth
    #   grpcio
    #   python-dateutil
snowballstemmer==2.2.0
    # via pydocstyle
tb-nightly==2.11.0a20220912
    # via
    #   basicsr
    #   gfpgan
tensorboard==2.10.0
    # via pytorch-lightning
tensorboard-data-server==0.6.1
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
tensorboard-plugin-wit==1.8.1
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
tifffile==2022.8.12
    # via scikit-image
tokenizers==0.12.1
    # via transformers
tomli==2.0.1
@@ -210,10 +319,14 @@ tomlkit==0.11.4
    # via pylint
torch==1.12.1
    # via
    #   basicsr
    #   diffusers
    #   facexlib
    #   gfpgan
    #   imaginAIry (setup.py)
    #   kornia
    #   pytorch-lightning
    #   realesrgan
    #   torchmetrics
    #   torchvision
torchmetrics==0.6.0
@@ -221,12 +334,21 @@ torchmetrics==0.6.0
    #   imaginAIry (setup.py)
    #   pytorch-lightning
torchvision==0.13.1
    # via imaginAIry (setup.py)
    # via
    #   basicsr
    #   facexlib
    #   gfpgan
    #   imaginAIry (setup.py)
    #   realesrgan
tqdm==4.64.1
    # via
    #   basicsr
    #   facexlib
    #   gfpgan
    #   huggingface-hub
    #   imaginAIry (setup.py)
    #   pytorch-lightning
    #   realesrgan
    #   transformers
transformers==4.19.2
    # via imaginAIry (setup.py)
@@ -238,12 +360,22 @@ typing-extensions==4.3.0
    #   torchvision
urllib3==1.26.12
    # via requests
wcwidth==0.2.5
    # via ftfy
werkzeug==2.2.2
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
wheel==0.37.1
    # via tensorboard
    # via
    #   tb-nightly
    #   tensorboard
wrapt==1.14.1
    # via astroid
yapf==0.32.0
    # via
    #   basicsr
    #   gfpgan
yarl==1.8.1
    # via aiohttp
zipp==3.8.1
setup.py
@@ -23,7 +23,7 @@ setup(
    install_requires=[
        "click",
        "ftfy",  # for vendored clip
        "torch",
        "torch>=1.2.0",
        "numpy",
        "tqdm",
        "diffusers",
@@ -37,5 +37,7 @@ setup(
        "kornia==0.6",
        # k-diffusion for use with find_noise.py
        # "k-diffusion@git+https://github.com/crowsonkb/k-diffusion.git@71ba7d6735e9cba1945b429a21345960eb3f151c#egg=k-diffusion",
        "realesrgan",
        "gfpgan>=1.3.7",
    ],
)
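The two new runtime dependencies are what power the face-fixing and upscaling features. As a hedged sketch of their typical public APIs (this is not necessarily how imaginAIry wires them internally, and the weight-file paths below are placeholders you would have to supply):

```python
# Hedged sketch of how gfpgan and realesrgan are commonly driven.
# "GFPGANv1.3.pth" and "RealESRGAN_x4plus.pth" are placeholder weight paths.
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer

img = cv2.imread("input.jpg", cv2.IMREAD_COLOR)  # BGR uint8, as both APIs expect

# Face enhancement: detect faces, restore them, paste them back into the image.
face_restorer = GFPGANer(model_path="GFPGANv1.3.pth", upscale=1)
_, _, restored = face_restorer.enhance(img, paste_back=True)

# 4x upscaling with the standard RealESRGAN x4plus architecture.
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                num_grow_ch=32, scale=4)
upsampler = RealESRGANer(scale=4, model_path="RealESRGAN_x4plus.pth", model=model)
upscaled, _ = upsampler.enhance(restored, outscale=4)
cv2.imwrite("output.jpg", upscaled)
```

Pinning `gfpgan>=1.3.7` while leaving `realesrgan` unpinned matches the lockfile above, where realesrgan resolves to 0.2.5.0.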
@@ -1,4 +1,4 @@
from imaginairy.api import imagine_image_files, imagine_images
from imaginairy.api import imagine, imagine_image_files
from imaginairy.schema import ImaginePrompt

from . import TESTS_FOLDER
@@ -8,7 +8,7 @@ def test_imagine():
    prompt = ImaginePrompt(
        "a scenic landscape", width=512, height=256, steps=20, seed=1
    )
    result = next(imagine_images(prompt))
    result = next(imagine(prompt))
    assert result.md5() == "4c5957c498881d365cfcf13014812af0"
    result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
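The test reflects an API rename: `imagine_images` becomes `imagine`. A hedged sketch of the renamed surface, mirroring the call pattern the test exercises (the output directory name is hypothetical, and the `imagine_image_files` signature is assumed from the import):

```python
# `imagine` is a generator that yields one result per prompt; each result
# carries a PIL image (result.img), as the test above demonstrates.
# "outputs" is a hypothetical directory, not a path mandated by the project.
from imaginairy.api import imagine, imagine_image_files
from imaginairy.schema import ImaginePrompt

prompt = ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1)

result = next(imagine(prompt))
result.img.save("scenic_landscape.png")

# Or let the library write the files itself.
imagine_image_files(prompt, outdir="outputs")
```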