tests: add some utils tests

pull/51/head
Bryce 2 years ago committed by Bryce Drennan
parent 0244d4151f
commit e5c5df6b3d

@ -22,7 +22,7 @@ from imaginairy.modules.diffusion.util import (
noise_like,
)
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config, log_params
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
@ -95,7 +95,6 @@ class DDPM(pl.LightningModule):
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
log_params(self.model)
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:

@ -11,7 +11,7 @@ from PIL import Image, ImageOps
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
from imaginairy.utils import get_device, get_device_name
from imaginairy.utils import get_device, get_hardware_description
logger = logging.getLogger(__name__)
@ -219,7 +219,7 @@ class ImagineResult:
self.is_nsfw = is_nsfw
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.torch_backend = get_device()
self.hardware_name = get_device_name(get_device())
self.hardware_name = get_hardware_description(get_device())
def md5(self):
return hashlib.md5(self.img.tobytes()).hexdigest()

@ -4,7 +4,7 @@ import os.path
import platform
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import List, Optional
from typing import Any, List, Optional, Union
import requests
import torch
@ -17,7 +17,8 @@ logger = logging.getLogger(__name__)
@lru_cache()
def get_device():
def get_device() -> str:
"""Return the best torch backend available"""
if torch.cuda.is_available():
return "cuda"
@ -28,33 +29,40 @@ def get_device():
@lru_cache()
def get_device_name(device_type):
def get_hardware_description(device_type: str) -> str:
"""Description of the hardware being used"""
desc = platform.platform()
if device_type == "cuda":
return torch.cuda.get_device_name(0)
return platform.processor()
desc += "-" + torch.cuda.get_device_name(0)
return desc
def log_params(model):
total_params = sum(p.numel() for p in model.parameters())
logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
def get_obj_from_str(import_path: str, reload=False) -> Any:
    """
    Resolve a dotted import path to the python object it names.

    Example: "functools.lru_cache"
    """
    module_path, obj_name = import_path.rsplit(".", 1)
    if reload:
        # Re-execute the module so the attribute fetched below is freshly built.
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, obj_name)
def instantiate_from_config(config: Union[dict, str]) -> Any:
    """
    Instantiate an object from a config dict.

    `config` must carry a `target` key holding the dotted import path of a
    class/callable, and may carry a `params` dict of keyword arguments for it.
    The sentinel strings "__is_first_stage__" and "__is_unconditional__" are
    accepted and map to `None`.

    Raises:
        KeyError: if `config` has no `target` key and is not a sentinel.
    """
    if "target" not in config:
        # Sentinel placeholders used by model configs to mean "nothing here".
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params", {})
    _cls = get_obj_from_str(config["target"])
    return _cls(**params)
@contextmanager

@ -0,0 +1,84 @@
import platform
from datetime import datetime
from functools import lru_cache
from unittest import mock
import pytest
import torch.backends.mps
import torch.cuda
from imaginairy.utils import (
get_device,
get_hardware_description,
get_obj_from_str,
instantiate_from_config,
platform_appropriate_autocast,
)
def test_get_device(monkeypatch):
    """Device selection honors cuda > mps > cpu availability ordering."""
    # Run once un-mocked to confirm the real call doesn't error.
    get_device()

    m_cuda_is_available = mock.MagicMock()
    m_mps_is_available = mock.MagicMock()
    monkeypatch.setattr(torch.cuda, "is_available", m_cuda_is_available)
    monkeypatch.setattr(torch.backends.mps, "is_available", m_mps_is_available)

    # (cuda available?, mps available?) -> expected backend string
    scenarios = [
        (True, False, "cuda"),
        (False, True, "mps:0"),
        (False, False, "cpu"),
    ]
    for cuda_on, mps_on, expected in scenarios:
        get_device.cache_clear()
        m_cuda_is_available.side_effect = lambda _flag=cuda_on: _flag
        m_mps_is_available.side_effect = lambda _flag=mps_on: _flag
        assert get_device() == expected
def test_get_hardware_description(monkeypatch):
    """Description is the platform string, with the GPU name appended on cuda."""
    fake_platform = "macOS-12.5.1-arm64-arm-64bit-z"

    monkeypatch.setattr(platform, "platform", lambda: fake_platform)
    assert get_hardware_description("cpu") == fake_platform

    monkeypatch.setattr(platform, "platform", lambda: fake_platform)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "get_device_name", lambda _device: "rtx-3090")
    assert get_hardware_description("cuda") == fake_platform + "-rtx-3090"
def test_get_obj_from_str():
    """Dotted-path lookup returns the real object; reload yields a new one."""
    assert get_obj_from_str("functools.lru_cache") == lru_cache
    # Reloading re-executes functools, so the freshly fetched attribute is no
    # longer the same object as the one imported at the top of this module.
    assert get_obj_from_str("functools.lru_cache", reload=True) != lru_cache
def test_instantiate_from_config():
    """Dict configs build objects; sentinels give None; anything else raises."""
    built = instantiate_from_config(
        {
            "target": "datetime.datetime",
            "params": {"year": 2002, "month": 10, "day": 1},
        }
    )
    assert built == datetime(2002, 10, 1)

    for sentinel in ("__is_first_stage__", "__is_unconditional__"):
        assert instantiate_from_config(sentinel) is None

    with pytest.raises(KeyError):
        instantiate_from_config("asdf")
def test_platform_appropriate_autocast():
    """Smoke test: the autocast context manager enters and exits cleanly."""
    with platform_appropriate_autocast("autocast"):
        pass

@ -18,7 +18,7 @@ ignore =
Z999,W0221,W0511,W1203
[pylama:tests/*]
ignore = C0114,C0116,D103,W0613
ignore = C0104,C0114,C0116,D103,W0143,W0613
[pylama:*/__init__.py]
ignore = D104

Loading…
Cancel
Save