diff --git a/Makefile b/Makefile
index bd4d4e0..6b37a6d 100644
--- a/Makefile
+++ b/Makefile
@@ -51,6 +51,12 @@ deploy: ## Deploy the package to pypi.org
 	rm -rf dist
 	@echo "Deploy successful! ✨ 🍰 ✨"
 
+build-dev-image: ## Build the imaginairy-dev docker image from tests/Dockerfile
+	docker build -f tests/Dockerfile -t imaginairy-dev .
+
+run-dev: build-dev-image ## Open a bash shell in the imaginairy-dev container
+	docker run -it -v $$HOME/.cache/huggingface:/root/.cache/huggingface -v $$HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy-dev /bin/bash
+
 requirements: ## Freeze the requirements.txt file
 	pip-compile setup.py requirements-dev.in --output-file=requirements-dev.txt --upgrade
 
diff --git a/imaginairy/api.py b/imaginairy/api.py
index c999b3f..2b1f066 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -29,6 +29,7 @@ from imaginairy.samplers.base import get_sampler
 from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import (
     expand_mask,
+    fix_torch_group_norm,
     fix_torch_nn_layer_norm,
     get_device,
     instantiate_from_config,
@@ -156,13 +157,16 @@ def imagine(
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None
-
+    if get_device() == "cpu":
+        logger.info("Running in CPU mode. it's gonna be slooooooow.")
     precision_scope = (
         autocast
         if precision == "autocast" and get_device() in ("cuda", "cpu")
         else nullcontext
     )
-    with torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm():
+    with torch.no_grad(), precision_scope(
+        get_device()
+    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
         for prompt in prompts:
             with ImageLoggingContext(
                 prompt=prompt,
diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py
index 84fc257..8d16106 100644
--- a/imaginairy/cmds.py
+++ b/imaginairy/cmds.py
@@ -149,6 +149,12 @@ def configure_logging(level="INFO"):
     is_flag=True,
     help="Generate a text description of the generated image",
 )
+@click.option(
+    "--precision",
+    help="evaluate at this precision",
+    type=click.Choice(["full", "autocast"]),
+    default="autocast",
+)
 @click.pass_context
 def imagine_cmd(
     ctx,
@@ -174,6 +180,7 @@ def imagine_cmd(
     mask_mode,
     mask_expansion,
     caption,
+    precision,
 ):
     """Have the AI generate images. alias:imagine"""
     if ctx.invoked_subcommand is not None:
@@ -220,6 +227,7 @@ def imagine_cmd(
         record_step_images="images" in show_work,
         output_file_extension="png",
         print_caption=caption,
+        precision=precision,
     )
 
 
diff --git a/imaginairy/utils.py b/imaginairy/utils.py
index b382c36..ca75610 100644
--- a/imaginairy/utils.py
+++ b/imaginairy/utils.py
@@ -2,7 +2,7 @@ import importlib
 import logging
 import os.path
 import platform
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import lru_cache
 from typing import List, Optional
 
@@ -10,7 +10,7 @@ import numpy as np
 import requests
 import torch
 from PIL import Image, ImageFilter
-from torch import Tensor
+from torch import Tensor, autocast
 from torch.nn import functional
 from torch.overrides import handle_torch_function, has_torch_function_variadic
 from transformers import cached_path
@@ -104,6 +104,43 @@ def fix_torch_nn_layer_norm():
         functional.layer_norm = orig_function
 
 
+@contextmanager
+def fix_torch_group_norm():
+    """
+    Patch group_norm to cast the weight and bias to the input dtype; undone on exit.
+
+    From what I can understand all the other repos just switch to full precision instead
+    of addressing this. I think this would make things slower but I'm not sure.
+
+    https://github.com/pytorch/pytorch/pull/81852
+
+    """
+
+    orig_group_norm = functional.group_norm
+
+    def _group_norm_wrapper(
+        input: Tensor,
+        num_groups: int,
+        weight: Optional[Tensor] = None,
+        bias: Optional[Tensor] = None,
+        eps: float = 1e-5,
+    ) -> Tensor:
+        if weight is not None and weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
+        if bias is not None and bias.dtype != input.dtype:
+            bias = bias.to(input.dtype)
+
+        return orig_group_norm(
+            input=input, num_groups=num_groups, weight=weight, bias=bias, eps=eps
+        )
+
+    functional.group_norm = _group_norm_wrapper
+    try:
+        yield
+    finally:
+        functional.group_norm = orig_group_norm
+
+
 def expand_mask(mask_image, size):
     if size < 0:
         threshold = 0.95