imaginAIry/imaginairy/utils.py
2022-09-07 20:59:30 -07:00

38 lines
1.0 KiB
Python

import importlib
from functools import lru_cache
import torch
@lru_cache()
def get_device():
    """Return the name of the preferred torch device on this machine.

    The choice is made in priority order: ``"cuda"`` when CUDA is
    available, then ``"mps"`` when the MPS backend is available, and
    ``"cpu"`` otherwise.  The result is memoized by ``lru_cache``, so the
    availability probes run only on the first call.

    Returns:
        str: one of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # PyTorch builds that predate the MPS backend have no
    # `torch.backends.mps` attribute; look it up defensively so those
    # installs fall through to "cpu" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
def print_params(model):
    """Print the class name of ``model`` and its parameter count in millions.

    The count is the sum of ``numel()`` over everything yielded by
    ``model.parameters()``, reported with two decimal places.
    """
    param_count = 0
    for param in model.parameters():
        param_count += param.numel()

    model_name = model.__class__.__name__
    millions = param_count * 1.e-6
    print(f"{model_name} has {millions:.2f} M params.")
def instantiate_from_config(config):
    """Instantiate the object described by ``config``.

    ``config`` is normally a mapping with a ``"target"`` entry holding a
    dotted import path (resolved with ``get_obj_from_str``) and an
    optional ``"params"`` entry holding the keyword arguments passed to
    that target.

    Returns:
        The object produced by calling the target with ``params``, or
        ``None`` when ``config`` is one of the sentinel strings
        ``"__is_first_stage__"`` / ``"__is_unconditional__"``.

    Raises:
        KeyError: if ``config`` has no ``"target"`` and is not one of the
            sentinel strings.
    """
    if "target" not in config:
        # Both sentinel values mean "nothing to build here".
        if config == "__is_first_stage__" or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")

    factory = get_obj_from_str(config["target"])
    params = config.get("params")
    if params is None:
        # Covers both a missing key and a key explicitly set to None
        # (for instance a config file that lists `params` with no
        # value); `**None` would otherwise raise TypeError.
        params = {}
    return factory(**params)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    The text after the last dot is looked up as an attribute on the
    module named by the text before it.  When ``reload`` is true the
    module is imported and reloaded before the attribute lookup.
    """
    module_path, attr_name = string.rsplit(".", 1)

    if reload:
        # Import first so there is a module object to hand to reload().
        importlib.reload(importlib.import_module(module_path))

    resolved_module = importlib.import_module(module_path)
    return getattr(resolved_module, attr_name)