docs and lint

pull/1/head
Bryce 2 years ago
parent 541ecb9701
commit ff7455034d

@ -58,12 +58,14 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
```bash
>> imagine "a couple smiling" --steps 40 --seed 1 --fix-faces
```
<img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> => <img src="assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256">
=>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
### Upscaling [by RealESRGAN](https://github.com/xinntao/Real-ESRGAN)
<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg">
<img src="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="512">
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg" height="128"> =>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="256">
## Features
@ -71,6 +73,7 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
- Generate images either in code or from command line.
- It just works. Proper requirements are installed. Model weights are automatically downloaded. No Hugging Face account needed.
(if you have the right hardware... and aren't on Windows)
- No more distorted faces!
- Noisy logs are gone (which was surprisingly hard to accomplish)
- WeightedPrompts let you smash together separate prompts (cat-dog)
- Tile Mode creates tileable images

@ -252,7 +252,7 @@ def imagine(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=10))
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")

@ -631,49 +631,48 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.info("reducing Kernel")
if (
hasattr(self, "split_input_params")
and self.split_input_params["patch_distributed_vq"]
):
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.info("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
logger.info("reducing stride")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
logger.info("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
output_list = [
self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
output_list = [
self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1)
o = o * weighting
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
return self.first_stage_model.encode(x)
def apply_model(self, x_noisy, t, cond, return_ids=False):
@ -814,8 +813,8 @@ class LatentDiffusion(DDPM):
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
return x_recon
def p_mean_variance(
self,
@ -851,16 +850,16 @@ class LatentDiffusion(DDPM):
if clip_denoised:
x_recon.clamp_(-1.0, 1.0)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
x_recon, _, _ = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t
)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
if return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(
@ -890,10 +889,7 @@ class LatentDiffusion(DDPM):
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
if return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
@ -904,17 +900,13 @@ class LatentDiffusion(DDPM):
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (
0.5 * model_log_variance
).exp() * noise, logits.argmax(dim=1)
if return_x0:
return (
model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
x0,
)
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
class DiffusionWrapper(pl.LightningModule):

@ -11,6 +11,7 @@ import requests
import torch
from PIL import Image
from torch import Tensor
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path
logger = logging.getLogger(__name__)
@ -56,9 +57,6 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)
from torch.overrides import handle_torch_function, has_torch_function_variadic
def _fixed_layer_norm(
input: Tensor,
normalized_shape: List[int],

@ -7,7 +7,7 @@ ignore =
Z999,D100,D101,D102,D103,D107,D202,D203,D212,D400,D401,D415,
Z999,E501,E1101,
Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
Z999,W0511,W1203
Z999,W0221,W0511,W1203
[pylama:tests/*]
ignore = C0114,C0116,D103,W0613

Loading…
Cancel
Save