feature: inpainting model support; improved model manager

pull/85/head
Authored by Bryce 2 years ago; committed by Bryce Drennan
parent 54cc874cba
commit 021a0c540d

@@ -6,7 +6,7 @@ import numpy as np
import PIL
import torch
import torch.nn
-from einops import rearrange
+from einops import rearrange, repeat
from PIL import Image, ImageDraw, ImageFilter, ImageOps
from pytorch_lightning import seed_everything
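Note: the newly imported `repeat` is used below to tile single-image tensors out to the batch dimension. A quick illustration of the pattern (illustrative shapes):

    import torch
    from einops import repeat

    t = torch.randn(1, 4, 64, 64)
    tiled = repeat(t, "1 ... -> n ...", n=2)  # tile the singleton batch axis
    assert tiled.shape == (2, 4, 64, 64)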
@@ -128,7 +128,9 @@ def imagine(
f"Generating 🖼 {i + 1}/{num_prompts}: {prompt.prompt_description()}"
)
model = get_diffusion_model(
-weights_location=prompt.model, half_mode=half_mode
+weights_location=prompt.model,
+half_mode=half_mode,
+for_inpainting=prompt.mask_image or prompt.mask_prompt,
)
with ImageLoggingContext(
prompt=prompt,
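Note: `for_inpainting` receives whatever truthy value the prompt carries, so supplying either a mask image or a mask prompt automatically selects the inpainting weights. A usage sketch (the file path is hypothetical, and the exact ImaginePrompt arguments may differ between versions):

    from imaginairy import ImaginePrompt, imagine

    prompt = ImaginePrompt(
        "a bowl of gold coins",
        init_image="my_photo.jpg",  # hypothetical input image
        mask_prompt="fruit",        # truthy, so SD-1.5-inpaint gets loaded
    )
    result = next(imagine([prompt]))  # imagine() yields results lazily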
@@ -242,9 +244,60 @@ def imagine(
schedule=schedule,
noise=noise,
)
batch_size = 1
log_latent(init_latent_noised, "init_latent_noised")
batch = {
"txt": batch_size * [prompt.prompt_text],
}
c_cat = []
if mask_image_orig:
mask_t = pillow_img_to_torch_image(
ImageOps.invert(mask_image_orig)
).to(get_device())
inverted_mask = 1 - mask
masked_image_t = init_image_t * (mask_t < 0.5)
batch.update(
{
"image": repeat(
init_image_t.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
"txt": batch_size * [prompt.prompt_text],
"mask": repeat(
inverted_mask.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
"masked_image": repeat(
masked_image_t.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
}
)
for concat_key in getattr(model, "concat_keys", []):
cc = batch[concat_key].float()
if concat_key != model.masked_image_key:
bchw = [batch_size, 4, shape[2], shape[3]]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(
model.encode_first_stage(cc)
)
c_cat.append(cc)
if c_cat:
c_cat = [torch.cat(c_cat, dim=1)]
positive_conditioning = {
"c_concat": c_cat,
"c_crossattn": [positive_conditioning],
}
neutral_conditioning = {
"c_concat": c_cat,
"c_crossattn": [neutral_conditioning],
}
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
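Note: the batch assembled above feeds a hybrid-conditioned UNet. The mask ends up interpolated to latent resolution (1 channel), the masked image goes through the VAE encoder (4 channels), and both are concatenated onto the 4-channel noised latent. A shape sketch with illustrative sizes:

    import torch

    batch_size, h, w = 1, 64, 64               # latent-space resolution
    z = torch.randn(batch_size, 4, h, w)       # noised latent
    mask = torch.rand(batch_size, 1, h, w)     # mask, resized to the latent grid
    masked = torch.randn(batch_size, 4, h, w)  # VAE-encoded masked image

    c_cat = torch.cat([mask, masked], dim=1)   # 5 conditioning channels
    unet_in = torch.cat([z, c_cat], dim=1)
    assert unet_in.shape[1] == 9               # matches in_channels: 9 in the config below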

@@ -0,0 +1,70 @@
model:
  base_learning_rate: 7.5e-05
  target: imaginairy.modules.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: imaginairy.modules.diffusion.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 noised latent + 4 masked-image latent + 1 downsampled mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: imaginairy.modules.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: imaginairy.modules.clip_embedders.FrozenCLIPEmbedder
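Note: `conditioning_key: hybrid` is what routes the two conditioning streams differently: `c_concat` tensors are stacked onto the latent's channel axis while `c_crossattn` is passed as cross-attention context. A sketch paraphrasing how ldm-style DiffusionWrapper code handles the hybrid case (not a verbatim copy):

    import torch

    def hybrid_forward(unet, x, t, c_concat, c_crossattn):
        xc = torch.cat([x] + c_concat, dim=1)  # channel-wise concat conditioning
        cc = torch.cat(c_crossattn, dim=1)     # text embedding for cross-attention
        return unet(xc, t, context=cc)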

@@ -6,7 +6,8 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from imaginairy.utils import get_cached_url_path, get_device
+from imaginairy.model_manager import get_cached_url_path
+from imaginairy.utils import get_device
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint
device = get_device()

@@ -8,7 +8,7 @@ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
-from imaginairy.utils import get_cached_url_path
+from imaginairy.model_manager import get_cached_url_path
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
logger = logging.getLogger(__name__)

@@ -6,7 +6,8 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
-from imaginairy.utils import get_cached_url_path, get_device
+from imaginairy.model_manager import get_cached_url_path
+from imaginairy.utils import get_device
@lru_cache()

@@ -43,9 +43,9 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image):
def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
-from imaginairy.model_manager import get_diffusion_model # noqa
+from imaginairy.model_manager import get_current_diffusion_model # noqa
-model = get_diffusion_model()
+model = get_current_diffusion_model()
latents = model.decode_first_stage(latents)
latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
imgs = []
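Note: `decode_first_stage` outputs live in [-1, 1]; the clamp above maps them to [0, 1]. The rest of the loop (outside this hunk) presumably converts each tensor to a PIL image, roughly:

    import numpy as np
    import torch
    from PIL import Image

    decoded = torch.clamp((torch.randn(1, 3, 512, 512) + 1.0) / 2.0, min=0.0, max=1.0)
    arr = (255.0 * decoded[0].permute(1, 2, 0).cpu().numpy()).astype(np.uint8)
    img = Image.fromarray(arr)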

@@ -1,10 +1,14 @@
import gc
+import glob
import logging
import os
+import requests
import torch
from omegaconf import OmegaConf
from transformers import cached_path
+from transformers.utils.hub import TRANSFORMERS_CACHE, HfFolder
+from transformers.utils.hub import url_to_filename as tf_url_to_filename
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import get_device, instantiate_from_config
@@ -14,16 +18,25 @@ logger = logging.getLogger(__name__)
MODEL_SHORTCUTS = {
"SD-1.4": (
"configs/stable-diffusion-v1.yaml",
"https://huggingface.co/bstddev/sd-v1-4/resolve/main/sd-v1-4.ckpt",
"https://huggingface.co/bstddev/sd-v1-4/resolve/77221977fa8de8ab8f36fac0374c120bd5b53287/sd-v1-4.ckpt",
),
"SD-1.5": (
"configs/stable-diffusion-v1.yaml",
"https://huggingface.co/acheong08/SD-V1-5-cloned/resolve/main/v1-5-pruned-emaonly.ckpt",
"https://huggingface.co/acheong08/SD-V1-5-cloned/resolve/fc392f6bd4345b80fc2256fa8aded8766b6c629e/v1-5-pruned-emaonly.ckpt",
),
"SD-1.5-inpaint": (
"configs/stable-diffusion-v1-inpaint.yaml",
"https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/9f492cedac6a1a2993f0b6ba44bb71b96a8aa9e6/sd-v1-5-inpainting.ckpt",
),
}
DEFAULT_MODEL = "SD-1.5"
LOADED_MODELS = {}
+MOST_RECENTLY_LOADED_MODEL = None
+class HuggingFaceAuthorizationError(RuntimeError):
+    pass
class MemoryAwareModel:
@@ -68,7 +81,7 @@ class MemoryAwareModel:
def load_model_from_config(config, weights_location):
if weights_location.startswith("http"):
-ckpt_path = cached_path(weights_location)
+ckpt_path = get_cached_url_path(weights_location)
else:
ckpt_path = weights_location
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
@@ -80,7 +93,7 @@ def load_model_from_config(config, weights_location):
if weights_location.startswith("http"):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
-ckpt_path = cached_path(weights_location)
+ckpt_path = get_cached_url_path(weights_location)
pl_sd = torch.load(ckpt_path, map_location="cpu")
if pl_sd is None:
raise e
@@ -104,16 +117,45 @@ def get_diffusion_model(
weights_location=DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
+for_inpainting=False,
):
"""
Load a diffusion model
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
try:
return _get_diffusion_model(
weights_location, config_path, half_mode, for_inpainting
)
except HuggingFaceAuthorizationError as e:
if for_inpainting:
logger.warning(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}"
)
return _get_diffusion_model(
DEFAULT_MODEL, config_path, half_mode, for_inpainting=False
)
raise e
def _get_diffusion_model(
weights_location=DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
):
"""
Load a diffusion model
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
+global MOST_RECENTLY_LOADED_MODEL # noqa
if weights_location is None:
weights_location = DEFAULT_MODEL
-if weights_location in MODEL_SHORTCUTS:
+if for_inpainting and f"{weights_location}-inpaint" in MODEL_SHORTCUTS:
+    config_path, weights_location = MODEL_SHORTCUTS[f"{weights_location}-inpaint"]
+elif weights_location in MODEL_SHORTCUTS:
config_path, weights_location = MODEL_SHORTCUTS[weights_location]
key = (config_path, weights_location)
@@ -125,4 +167,87 @@ def get_diffusion_model(
model = LOADED_MODELS[key]
# calling model attribute forces it to load
model.num_timesteps_cond # noqa
+MOST_RECENTLY_LOADED_MODEL = model
return model
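Net effect of the wrapper: a shortcut name plus `for_inpainting=True` resolves to the `-inpaint` entry in MODEL_SHORTCUTS, and an unauthorized HuggingFace download degrades gracefully to the standard weights. For example:

    from imaginairy.model_manager import get_diffusion_model

    # Loads SD-1.5-inpaint; falls back to SD-1.5 if the inpainting
    # checkpoint download is unauthorized.
    model = get_diffusion_model("SD-1.5", for_inpainting=True)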
+def get_current_diffusion_model():
+    return MOST_RECENTLY_LOADED_MODEL
+
+
+def get_cache_dir():
+    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
+    if xdg_cache_home is None:
+        user_home = os.getenv("HOME", None)
+        if user_home:
+            xdg_cache_home = os.path.join(user_home, ".cache")
+    if xdg_cache_home is not None:
+        return os.path.join(xdg_cache_home, "imaginairy", "weights")
+    return os.path.join(os.path.dirname(__file__), ".cached-downloads")
+
+
+def get_cached_url_path(url):
+    """
+    Gets the contents of a url, but caches the response indefinitely
+
+    While we attempt to use the cached_path from huggingface transformers, we fall back
+    to our own implementation if the url does not provide an etag header, which `cached_path`
+    requires. We also skip the `head` call that `cached_path` makes on every call if the file
+    is already cached.
+    """
+    try:
+        return huggingface_cached_path(url)
+    except (OSError, ValueError):
+        pass
+    filename = url.split("/")[-1]
+    dest = get_cache_dir()
+    os.makedirs(dest, exist_ok=True)
+    dest_path = os.path.join(dest, filename)
+    if os.path.exists(dest_path):
+        return dest_path
+    r = requests.get(url) # noqa
+    with open(dest_path, "wb") as f:
+        f.write(r.content)
+    return dest_path
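One caveat in the fallback branch: `requests.get(url)` buffers the entire checkpoint (several gigabytes for these models) in memory before writing it out. A streaming variant would avoid that; this is an editor's sketch, not code from this commit:

    import requests

    def download_file(url, dest_path, chunk_size=1024 * 1024):
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(dest_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)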
+def find_url_in_huggingface_cache(url):
+    huggingface_filename = os.path.join(TRANSFORMERS_CACHE, tf_url_to_filename(url))
+    for name in glob.glob(huggingface_filename + "*"):
+        if name.endswith((".json", ".lock")):
+            continue
+        return name
+    return None
+
+
+def check_huggingface_url_authorized(url):
+    if not url.startswith("https://huggingface.co/"):
+        return None
+    token = HfFolder.get_token()
+    headers = {}
+    if token is not None:
+        headers["authorization"] = f"Bearer {token}"
+    response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
+    if response.status_code == 401:
+        raise HuggingFaceAuthorizationError(
+            "Unauthorized access to HuggingFace model. This model requires a huggingface token. "
+            "Please login to HuggingFace "
+            "or set HUGGING_FACE_HUB_TOKEN to your User Access Token. "
+            "See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
+        )
+    return None
+
+
+def huggingface_cached_path(url):
+    # bypass all the HEAD calls done by the default `cached_path`
+    dest_path = find_url_in_huggingface_cache(url)
+    if not dest_path:
+        check_huggingface_url_authorized(url)
+        token = HfFolder.get_token()
+        dest_path = cached_path(url, use_auth_token=token)
+    return dest_path
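The glob in `find_url_in_huggingface_cache` works because transformers names cache entries after a hash of the URL, optionally suffixed with a hash of the etag, and stores `.json` metadata and `.lock` files alongside, hence the wildcard plus the suffix filter. Roughly:

    from transformers.utils.hub import url_to_filename

    name = url_to_filename("https://example.com/weights.ckpt")  # a hash of the url
    # on disk: <cache>/<name> or <cache>/<name>.<etag-hash>, plus <name>.json / <name>.lock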

@@ -897,3 +897,58 @@ class DiffusionWrapper(pl.LightningModule):
raise NotImplementedError()
return out
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__( # noqa
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None, # noqa
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert self.concat_keys is not None
+        c_cat = []
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
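Training-time `get_input` therefore produces the same conditioning structure that `imagine()` now assembles by hand at sampling time, something like (dummy tensors, shapes per the config above):

    import torch

    c_cat = torch.randn(1, 5, 64, 64)  # 1 mask channel + 4 masked-image latent channels
    c = torch.randn(1, 77, 768)        # CLIP text embedding (context_dim: 768)
    all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}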

@@ -16,7 +16,6 @@ import torch
from einops import repeat as e_repeat
from torch import nn
-from imaginairy.log_utils import log_tensor
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
@@ -208,7 +207,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
)
else:
embedding = e_repeat(timesteps, "b -> b d", d=dim)
-log_tensor(embedding, "timestep_embedding")
return embedding

@@ -5,7 +5,7 @@ import numpy as np
import torch
from torch import nn
-from imaginairy.log_utils import log_img, log_latent
+from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
@@ -125,7 +125,21 @@ def get_noise_prediction(
noisy_latent_in = torch.cat([noisy_latent] * 2)
time_encoding_in = torch.cat([time_encoding] * 2)
-conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])
+if isinstance(positive_conditioning, dict):
+    assert isinstance(neutral_conditioning, dict)
+    conditioning_in = {}
+    for k in positive_conditioning:
+        if isinstance(positive_conditioning[k], list):
+            conditioning_in[k] = [
+                torch.cat([neutral_conditioning[k][i], positive_conditioning[k][i]])
+                for i in range(len(positive_conditioning[k]))
+            ]
+        else:
+            conditioning_in[k] = torch.cat(
+                [neutral_conditioning[k], positive_conditioning[k]]
+            )
+else:
+    conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])
noise_pred_neutral, noise_pred_positive = denoise_func(
noisy_latent_in, time_encoding_in, conditioning_in
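Classifier-free guidance runs neutral and positive conditioning through the model as one doubled batch, so with dict conditioning each entry must be concatenated in the same neutral-then-positive order as the latents. A self-contained sketch of the branch above, with dummy tensors standing in for the real conditionings:

    import torch

    neutral = {"c_concat": [torch.randn(1, 5, 64, 64)], "c_crossattn": [torch.randn(1, 77, 768)]}
    positive = {"c_concat": [torch.randn(1, 5, 64, 64)], "c_crossattn": [torch.randn(1, 77, 768)]}

    conditioning_in = {
        k: [torch.cat([neutral[k][i], positive[k][i]]) for i in range(len(positive[k]))]
        for k in positive
    }
    assert conditioning_in["c_crossattn"][0].shape[0] == 2  # neutral row + positive row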
@@ -164,8 +178,7 @@ def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
log_latent(hinted_orig_latent, f"hinted_orig_latent {ts}")
else:
hinted_orig_latent = noised_orig_latent
-log_img(mask, f"mask {ts}")
# logger.info(mask.shape)
hinted_orig_latent_masked = hinted_orig_latent * mask
log_latent(hinted_orig_latent_masked, f"hinted_orig_latent_masked {ts}")
noisy_latent_masked = (1.0 - mask) * noisy_latent

@@ -41,11 +41,6 @@ class DDIMSampler:
t_start=None,
quantize_x0=False,
):
-if positive_conditioning.shape[0] != batch_size:
-    raise ValueError(
-        f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
-    )
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,

@@ -77,10 +77,10 @@ class KDiffusionSampler:
initial_latent=None,
t_start=None,
):
-if positive_conditioning.shape[0] != batch_size:
-    raise ValueError(
-        f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
-    )
+# if positive_conditioning.shape[0] != batch_size:
+#     raise ValueError(
+#         f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+#     )
if initial_latent is None:
initial_latent = torch.randn(shape, device="cpu").to(self.device)
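Generating the noise on the CPU and then moving it, rather than sampling directly on the device, keeps a given seed reproducible across CUDA, MPS, and CPU backends; `randn_seeded` in utils.py uses the same idea. Sketch:

    import torch

    g = torch.Generator()  # CPU generator: deterministic across backends
    g.manual_seed(42)
    noise = torch.randn([1, 4, 64, 64], generator=g)  # then .to(device), as above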

@@ -45,10 +45,10 @@ class PLMSSampler:
quantize_denoised=False,
**kwargs,
):
-if positive_conditioning.shape[0] != batch_size:
-    raise ValueError(
-        f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
-    )
+# if positive_conditioning.shape[0] != batch_size:
+#     raise ValueError(
+#         f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+#     )
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
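These batch-size checks, deleted outright in the DDIM sampler and commented out in the K-diffusion and PLMS samplers, would now crash whenever `positive_conditioning` is a dict, which has no `.shape`. A dict-aware replacement would be possible, for instance (an editor's sketch, not part of this commit):

    def validate_batch_size(positive_conditioning, batch_size):
        if isinstance(positive_conditioning, dict):
            cond = positive_conditioning["c_crossattn"][0]
        else:
            cond = positive_conditioning
        if cond.shape[0] != batch_size:
            raise ValueError(
                f"Got {cond.shape[0]} conditionings but batch-size is {batch_size}"
            )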

@@ -1,17 +1,14 @@
import importlib
import logging
-import os.path
import platform
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional, Union
-import requests
import torch
from torch import Tensor, autocast
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
-from transformers import cached_path
logger = logging.getLogger(__name__)
@@ -157,37 +154,6 @@ def fix_torch_group_norm():
functional.group_norm = orig_group_norm
-def get_cache_dir():
-    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
-    if xdg_cache_home is None:
-        user_home = os.getenv("HOME", None)
-        if user_home:
-            xdg_cache_home = os.path.join(user_home, ".cache")
-    if xdg_cache_home is not None:
-        return os.path.join(xdg_cache_home, "imaginairy", "weights")
-    return os.path.join(os.path.dirname(__file__), ".cached-downloads")
-def get_cached_url_path(url):
-    try:
-        return cached_path(url)
-    except (OSError, ValueError):
-        pass
-    filename = url.split("/")[-1]
-    dest = get_cache_dir()
-    os.makedirs(dest, exist_ok=True)
-    dest_path = os.path.join(dest, filename)
-    if os.path.exists(dest_path):
-        return dest_path
-    r = requests.get(url) # noqa
-    with open(dest_path, "wb") as f:
-        f.write(r.content)
-    return dest_path
def randn_seeded(seed: int, size: List[int]) -> Tensor:
"""Generate a random tensor with a given seed"""
g_cpu = torch.Generator()

[43 binary files changed: test reference images updated; the web viewer shows only Before/After sizes (roughly 227–558 KiB each), not the images themselves.]