# imaginAIry/imaginairy/utils/__init__.py


import importlib
import logging
import platform
import random
import re
import time
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional

import numpy as np
import torch
from torch import Tensor, autocast
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic

logger = logging.getLogger(__name__)


@lru_cache
def get_device() -> str:
    """Return the best torch backend available."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


@lru_cache
def get_default_dtype():
    """Return the default dtype for torch."""
    if get_device() in ("cuda", "mps"):
        return torch.float16
    return torch.float32
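
# Illustrative usage: pair the detected device with its preferred dtype.
#   x = torch.zeros(16, device=get_device(), dtype=get_default_dtype())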


@lru_cache
def get_hardware_description(device_type: str) -> str:
    """Return a description of the hardware being used."""
    desc = platform.platform()
    if device_type == "cuda":
        desc += "-" + torch.cuda.get_device_name(0)
    return desc


def get_obj_from_str(import_path: str, reload=False) -> Any:
    """
    Get a python object from a string reference to its location.

    Example: "functools.lru_cache"
    """
    module_path, obj_name = import_path.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module_path)
        importlib.reload(module_imp)
    module = importlib.import_module(module_path, package=None)
    return getattr(module, obj_name)
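
# Illustrative usage: resolve a stdlib object from its dotted import path.
#   get_obj_from_str("functools.lru_cache")  # -> <function lru_cache ...>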


def instantiate_from_config(config: dict) -> Any:
    """Instantiate an object from a config dict."""
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        if config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params", {})
    _cls = get_obj_from_str(config["target"])
    start = time.perf_counter()
    c = _cls(**params)
    end = time.perf_counter()
    logger.debug(f"Instantiation of {_cls} took {end - start} seconds")
    return c
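
# Illustrative usage: `target` is any importable class; `params` its kwargs.
#   config = {"target": "datetime.timedelta", "params": {"hours": 1}}
#   instantiate_from_config(config)  # -> datetime.timedelta(seconds=3600)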


@contextmanager
def platform_appropriate_autocast(precision="autocast", enabled=True):
    """Allow calculations to run in mixed precision, which can be faster."""
    # autocast not supported on CPU
    # https://github.com/pytorch/pytorch/issues/55374
    # https://github.com/invoke-ai/InvokeAI/pull/518
    # NOTE: the trailing `and False` deliberately disables the autocast branch
    # for now; remove it to re-enable mixed precision on CUDA.
    if precision == "autocast" and get_device() in ("cuda",) and False:
        with autocast(get_device(), enabled=enabled):
            yield
    else:
        with nullcontext(get_device()):
            yield
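
# Illustrative usage (`model` and `x` are stand-ins): autocast would apply on
# CUDA when re-enabled; today this always takes the nullcontext branch.
#   with platform_appropriate_autocast():
#       result = model(x)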


def _fixed_layer_norm(
    input: Tensor,  # noqa
    normalized_shape: List[int],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Apply Layer Normalization over the last certain number of dimensions.

    Identical to `torch.nn.functional.layer_norm` except that the input is
    made contiguous first. See :class:`~torch.nn.LayerNorm` for details.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(
            _fixed_layer_norm,
            (input, weight, bias),
            input,
            normalized_shape,
            weight=weight,
            bias=bias,
            eps=eps,
        )
    return torch.layer_norm(
        input.contiguous(),
        normalized_shape,
        weight,
        bias,
        eps,
        torch.backends.cudnn.enabled,
    )


@contextmanager
def fix_torch_nn_layer_norm():
    """
    Patch layer_norm to make its input contiguous.

    https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1221416526
    """
    orig_function = functional.layer_norm
    functional.layer_norm = _fixed_layer_norm
    try:
        yield
    finally:
        functional.layer_norm = orig_function


@contextmanager
def fix_torch_group_norm():
    """
    Patch group_norm to cast the weights to the same type as the inputs.

    From what I can understand, all the other repos just switch to full
    precision instead of addressing this. I think that would make things
    slower, but I'm not sure.

    https://github.com/pytorch/pytorch/pull/81852
    """
    orig_group_norm = functional.group_norm

    def _group_norm_wrapper(
        input: Tensor,  # noqa
        num_groups: int,
        weight: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        eps: float = 1e-5,
    ) -> Tensor:
        # cast the affine parameters to the input dtype before delegating
        if weight is not None and weight.dtype != input.dtype:
            weight = weight.to(input.dtype)
        if bias is not None and bias.dtype != input.dtype:
            bias = bias.to(input.dtype)
        return orig_group_norm(
            input=input, num_groups=num_groups, weight=weight, bias=bias, eps=eps
        )

    functional.group_norm = _group_norm_wrapper
    try:
        yield
    finally:
        functional.group_norm = orig_group_norm
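
# Illustrative usage (`model` and `x` are stand-ins): both patches are context
# managers, so they can be stacked and are undone on exit.
#   with fix_torch_nn_layer_norm(), fix_torch_group_norm():
#       output = model(x)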


def randn_seeded(seed: int, size: List[int]) -> Tensor:
    """Generate a random tensor with a given seed."""
    from hashlib import md5

    g_cpu = torch.Generator()
    g_cpu.manual_seed(seed)
    noise = torch.randn(
        size,
        device="cpu",
        generator=g_cpu,
    )
    # md5 of the torch tensor `noise`
    torch_md5 = md5(noise.numpy().tobytes()).hexdigest()
    logger.debug(f"Made noise of size {size} from seed {seed}. md5:{torch_md5}")
    return noise
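
# Illustrative usage: the same seed always produces identical CPU noise.
#   a = randn_seeded(42, [1, 4, 8, 8])
#   b = randn_seeded(42, [1, 4, 8, 8])
#   assert torch.equal(a, b)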


def check_torch_working():
    """Check that torch is working."""
    try:
        torch.randn(1, device=get_device())
    except RuntimeError as e:
        if "CUDA" in str(e):
            msg = "CUDA is not working. Make sure you have a GPU and CUDA installed."
            raise RuntimeError(msg) from e
        raise


def frange(start, stop, step):
    """Like `range`, but supports floats."""
    x = start
    while x < stop:
        yield x
        x += step
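
# Illustrative usage: float steps, stop-exclusive like range().
#   list(frange(0, 1, 0.25))  # -> [0, 0.25, 0.5, 0.75]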


def shrink_list(items, max_size):
    """Evenly downsample `items` so at most `max_size` entries remain."""
    if len(items) <= max_size:
        return items
    removal_ratio = len(items) / (max_size - 1)
    new_items = {}
    for i, item in enumerate(items):
        # each bucket keeps only the last item that maps into it
        new_items[int(i / removal_ratio)] = item
    return [items[0], *list(new_items.values())]
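
# Illustrative usage: keep roughly evenly spaced items from a too-long list.
#   shrink_list(list(range(10)), 5)  # -> [0, 2, 4, 7, 9]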


def glob_expand_paths(paths):
    """Expand `~` and glob patterns in paths; pass URLs through unchanged."""
    import glob
    import os.path

    expanded_paths = []
    for p in paths:
        if p.startswith("http"):
            expanded_paths.append(p)
        else:
            p = os.path.expanduser(p)
            if os.path.exists(p) and os.path.isfile(p):
                expanded_paths.append(p)
            else:
                expanded_paths.extend(glob.glob(os.path.expanduser(p)))
    return expanded_paths
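
# Illustrative usage (results depend on the local filesystem):
#   glob_expand_paths(["~/images/*.png", "http://example.com/a.png"])
#   # -> matching local file paths, plus the URL passed through unchanged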


def get_next_filenumber(path):
    """Get the next file number in a directory."""
    import os

    filenames = os.listdir(path)
    if not filenames:
        return 0
    file_count = len(filenames)
    filenames.sort()
    try:
        last_file_name = filenames[-1]
        last_file_num = int(last_file_name.split("_")[0])
    except (ValueError, IndexError):
        last_file_num = 0
    return max(file_count, last_file_num + 1)


def check_torch_version():
    """
    Check that the torch version is compatible with ImaginAIry.

    https://github.com/brycedrennan/imaginAIry/issues/329
    """
    from packaging import version

    if version.parse(torch.__version__) < version.parse("2.0.0"):
        raise RuntimeError("ImaginAIry is not compatible with torch<2.0.0")


def exists(val):
    return val is not None


def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d
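
# Illustrative usage: `default` falls back to a value or a zero-arg callable.
#   default(None, 7)     # -> 7
#   default(None, list)  # -> [] (the callable is invoked)
#   default(5, 7)        # -> 5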


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def expand_dims_like(x, y):
    """Unsqueeze trailing dims of `x` until it has as many dims as `y`."""
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x
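
# Illustrative usage: make a per-batch tensor broadcastable against a 4-D one.
#   x = torch.ones(2)
#   y = torch.ones(2, 3, 4, 5)
#   expand_dims_like(x, y).shape  # -> torch.Size([2, 1, 1, 1])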


def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
    """
    Return the result of a recursive getattr call.

    E.g.:
        a.b.c
        = getattr(getattr(a, "b"), "c")
        = get_nested_attribute(a, "b.c")

    If any part of the attribute path is an integer x with current object a,
    a[x] is tried before a.x.
    """
    attributes = attribute_path.split(".")
    if depth is not None and depth > 0:
        attributes = attributes[:depth]
    assert len(attributes) > 0, "At least one attribute should be selected"
    current_attribute = obj
    current_key = None
    for level, attribute in enumerate(attributes):
        current_key = ".".join(attributes[: level + 1])
        try:
            id_ = int(attribute)
            current_attribute = current_attribute[id_]
        except ValueError:
            current_attribute = getattr(current_attribute, attribute)
    return (current_attribute, current_key) if return_key else current_attribute
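
# Illustrative usage (hypothetical `model` object): integer path segments
# index into sequences instead of using getattr.
#   get_nested_attribute(model, "encoder.layers.0.weight")
#   # equivalent to model.encoder.layers[0].weight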


def prompt_normalized(prompt, length=130):
    """Collapse characters outside `a-zA-Z0-9.,[]-` into `_` and truncate."""
    return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]
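
# Illustrative usage:
#   prompt_normalized("a photo of a cat, 4k!")  # -> "a_photo_of_a_cat,_4k_"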


def clear_gpu_cache():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def seed_everything(seed: int | None = None) -> None:
    """Seed the python, numpy, and torch random number generators."""
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
        logger.info(f"Using random seed: {seed}")
    random.seed(a=seed)
    np.random.seed(seed=seed)
    torch.manual_seed(seed=seed)
    torch.cuda.manual_seed_all(seed=seed)
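
# Illustrative usage: seeding all RNGs makes runs repeatable.
#   seed_everything(123)
#   first = torch.rand(2)
#   seed_everything(123)
#   assert torch.equal(first, torch.rand(2))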