You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
imaginAIry/imaginairy/utils.py

177 lines
4.8 KiB
Python

import importlib
import logging
import os.path
import platform
from contextlib import contextmanager
from functools import lru_cache
from typing import List, Optional
import numpy as np
import requests
import torch
from PIL import Image, ImageFilter
from torch import Tensor
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path
from imaginairy.img_log import log_img
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@lru_cache()
def get_device():
    """Return the torch device string to run on; probed once, then cached.

    Preference order: CUDA ("cuda"), then Apple MPS ("mps:0"), then "cpu".
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps:0"
    else:
        device = "cpu"
    return device
@lru_cache()
def get_device_name(device_type):
    """Return a human-readable name for *device_type* (cached per argument).

    For "cuda" this is the name of CUDA device 0; for anything else it is the
    host processor string reported by ``platform.processor()``.
    """
    return torch.cuda.get_device_name(0) if device_type == "cuda" else platform.processor()
def log_params(model):
    """Log (at DEBUG level) the total number of parameters in *model*, in millions."""
    count = 0
    for param in model.parameters():
        count += param.numel()
    logger.debug(f"{model.__class__.__name__} has {count * 1.e-6:.2f} M params.")
def instantiate_from_config(config):
    """Build an object described by a ``{"target": ..., "params": ...}`` config.

    ``config["target"]`` is a dotted path resolved with get_obj_from_str; the
    resulting callable is invoked with ``config["params"]`` as keyword arguments
    (or with no arguments when ``params`` is absent).

    The sentinel strings ``"__is_first_stage__"`` and ``"__is_unconditional__"``
    are accepted in place of a mapping and produce ``None``.

    Raises:
        KeyError: if *config* has no ``target`` and is not one of the sentinels.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", {})
        return factory(**kwargs)
    if config == "__is_first_stage__" or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
def get_obj_from_str(string, reload=False):
    """Return the attribute named by a dotted path such as ``"pkg.module.Name"``.

    Everything before the last dot is imported as a module; the final segment is
    looked up on it. With ``reload=True`` the module is re-executed via
    ``importlib.reload`` before the lookup.

    Raises:
        ValueError: if *string* contains no dot.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
def _fixed_layer_norm(
input: Tensor, # noqa
normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
) -> Tensor:
"""
Applies Layer Normalization for last certain number of dimensions.
See :class:`~torch.nn.LayerNorm` for details.
"""
if has_torch_function_variadic(input, weight, bias):
return handle_torch_function(
_fixed_layer_norm,
(input, weight, bias),
input,
normalized_shape,
weight=weight,
bias=bias,
eps=eps,
)
return torch.layer_norm(
input.contiguous(),
normalized_shape,
weight,
bias,
eps,
torch.backends.cudnn.enabled,
)
@contextmanager
def fix_torch_nn_layer_norm():
    """Temporarily replace ``torch.nn.functional.layer_norm`` with ``_fixed_layer_norm``.

    The original function is put back on exit, even if the ``with`` body raises.
    The swap is a plain module-attribute assignment, so it is process-wide while
    active.

    Context: https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1221416526
    """
    saved_layer_norm = getattr(functional, "layer_norm")
    setattr(functional, "layer_norm", _fixed_layer_norm)
    try:
        yield
    finally:
        setattr(functional, "layer_norm", saved_layer_norm)
def expand_mask(mask_image, size):
    """Grow or shrink a mask by Gaussian-blurring it and re-thresholding to black/white.

    Args:
        mask_image: PIL image; it is converted to single-channel "L" mode first.
        size: blur radius. For ``size >= 0`` the threshold is 5%, so any pixel the
            blur brightens even slightly becomes white (the white area grows). For
            ``size < 0`` the threshold is 95%, so only pixels that stay almost fully
            white survive (the white area shrinks). The blur itself uses ``abs(size)``.

    Returns:
        A new "L"-mode PIL image with the same dimensions, containing only 0 and 255.
    """
    threshold = 0.95 if size < 0 else 0.05
    mask_image = mask_image.convert("L")
    # The sign only selects the threshold; pass a non-negative radius explicitly
    # rather than relying on Pillow's blur happening to ignore the sign.
    mask_image = mask_image.filter(ImageFilter.GaussianBlur(abs(size)))
    log_img(mask_image, "init mask blurred")
    mask = np.array(mask_image).astype(np.float32) / 255.0
    # Keep the 2-D (H, W) shape throughout: squeezing an expanded array would also
    # drop genuine size-1 axes (1-row masks came back transposed, 1x1 masks failed).
    binary = np.where(mask >= threshold, 255, 0).astype(np.uint8)
    return Image.fromarray(binary)
def img_path_to_torch_image(path):
    """Load the image file at *path* as RGB and convert it with pillow_img_to_torch_image.

    The file is opened in a ``with`` block so the handle is released as soon as the
    pixels are decoded, instead of whenever the garbage collector reclaims it.
    """
    with Image.open(path) as img:
        # convert() loads the pixel data, so the returned image stays usable after
        # the file is closed.
        image = img.convert("RGB")
    logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
    return pillow_img_to_torch_image(image)
def pillow_fit_image_within(image, max_height=512, max_width=512):
    """Scale *image* to fit inside ``max_width`` x ``max_height`` with sides in multiples of 64.

    The aspect-preserving scale factor may enlarge as well as shrink the image.
    Each side is then rounded down to a multiple of 64, with a floor of 64. The
    floor stops very elongated images (aspect ratio above ~8:1 at the default
    bounds) from collapsing to a zero-pixel side. As a result, a side can exceed
    its bound only when that bound is itself below 64.

    Returns:
        ``(resized_rgb_image, width, height)``.
    """
    image = image.convert("RGB")
    w, h = image.size
    resize_ratio = min(max_width / w, max_height / h)
    w, h = int(w * resize_ratio), int(h * resize_ratio)
    # resize to integer multiple of 64, never smaller than 64
    w, h = (max(64, x - x % 64) for x in (w, h))
    image = image.resize((w, h), resample=Image.Resampling.NEAREST)
    return image, w, h
def pillow_img_to_torch_image(image):
    """Convert a PIL image into a float32 tensor of shape (1, 3, H, W) scaled to [-1, 1]."""
    # PIL's array view of an RGB image is (H, W, 3) uint8; rescale to [0, 1] first.
    rgb = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    batched_chw = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    return batched_chw * 2.0 - 1.0
def get_cache_dir():
    """Return the directory used to cache downloaded weights (it is not created here).

    Resolution order:
      1. ``$XDG_CACHE_HOME/imaginairy/weights`` if XDG_CACHE_HOME is an absolute path;
      2. ``$HOME/.cache/imaginairy/weights`` if HOME is set and non-empty;
      3. a ``.cached-downloads`` directory next to this module.
    """
    xdg_cache_home = os.getenv("XDG_CACHE_HOME")
    # Per the XDG Base Directory spec, an unset, empty or relative value must be
    # ignored. Previously "" slipped through and produced a cwd-relative path.
    if not xdg_cache_home or not os.path.isabs(xdg_cache_home):
        xdg_cache_home = None
        user_home = os.getenv("HOME")
        if user_home:
            xdg_cache_home = os.path.join(user_home, ".cache")
    if xdg_cache_home is not None:
        return os.path.join(xdg_cache_home, "imaginairy", "weights")
    return os.path.join(os.path.dirname(__file__), ".cached-downloads")
def get_cached_url_path(url):
    """Return a local filesystem path for the file at *url*, downloading it if needed.

    ``cached_path`` (from transformers) is tried first. If it raises ``OSError``,
    the file is downloaded into get_cache_dir() under the last ``/``-separated
    segment of *url*. An existing file there is reused without re-downloading.

    Raises:
        requests.HTTPError: if the fallback download returns an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    try:
        return cached_path(url)
    except OSError:
        # Deliberate best effort: fall back to the direct download below.
        logger.debug("cached_path failed for %s; downloading directly", url, exc_info=True)

    filename = url.split("/")[-1]
    dest = get_cache_dir()
    os.makedirs(dest, exist_ok=True)
    dest_path = os.path.join(dest, filename)
    if os.path.exists(dest_path):
        return dest_path
    _download_to_file(url, dest_path)
    return dest_path


def _download_to_file(url, dest_path, timeout=60):
    """Stream *url* into *dest_path*, making the file appear only once it is complete."""
    tmp_path = f"{dest_path}.partial"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTML error page would be cached as if it were the file.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in chunks so large weight files are never held fully in memory.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        # Atomic rename within the same directory: an interrupted download never
        # leaves a truncated file at dest_path that would later look "cached".
        os.replace(tmp_path, dest_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise