fix: try to address #13 BFloat16 issue

Seems to be caused by incompatible dtypes in group_norm when we use autocast.

Patch group_norm to cast its weights to the same dtype as its inputs.

From what I can tell, all the other repos just switch to full precision instead
of addressing this, which presumably makes things slower (I haven't measured it).
So the patching solution here may be the better trade-off.

https://github.com/pytorch/pytorch/pull/81852
Bryce — commit 09bc1c70e6 (parent e23e363bf5)
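
For context, a minimal sketch of the failure mode and of the cast-based workaround this commit applies (shapes, dtypes, and values below are illustrative, not taken from the repo):

    import torch
    from torch.nn import functional as F

    x = torch.randn(1, 8, 4, 4).to(torch.bfloat16)  # activations in bfloat16
    weight = torch.randn(8)  # affine parameters still in float32
    bias = torch.randn(8)

    # With mismatched dtypes, group_norm raises a RuntimeError on some builds:
    # F.group_norm(x, num_groups=2, weight=weight, bias=bias)

    # The workaround: cast the parameters to the input's dtype first.
    out = F.group_norm(x, num_groups=2, weight=weight.to(x.dtype), bias=bias.to(x.dtype))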

@@ -51,6 +51,12 @@ deploy: ## Deploy the package to pypi.org
 	rm -rf dist
 	@echo "Deploy successful! ✨ 🍰 ✨"
 
+build-dev-image:
+	docker build -f tests/Dockerfile -t imaginairy-dev .
+
+run-dev: build-dev-image
+	docker run -it -v $$HOME/.cache/huggingface:/root/.cache/huggingface -v $$HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy-dev /bin/bash
+
 requirements: ## Freeze the requirements.txt file
 	pip-compile setup.py requirements-dev.in --output-file=requirements-dev.txt --upgrade

@@ -29,6 +29,7 @@ from imaginairy.samplers.base import get_sampler
 from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import (
     expand_mask,
+    fix_torch_group_norm,
     fix_torch_nn_layer_norm,
     get_device,
     instantiate_from_config,
@@ -156,13 +157,16 @@ def imagine(
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None
+    if get_device() == "cpu":
+        logger.info("Running in CPU mode. it's gonna be slooooooow.")
     precision_scope = (
         autocast
         if precision == "autocast" and get_device() in ("cuda", "cpu")
         else nullcontext
     )
-    with torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm():
+    with torch.no_grad(), precision_scope(
+        get_device()
+    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
         for prompt in prompts:
             with ImageLoggingContext(
                 prompt=prompt,

@@ -149,6 +149,12 @@ def configure_logging(level="INFO"):
     is_flag=True,
     help="Generate a text description of the generated image",
 )
+@click.option(
+    "--precision",
+    help="evaluate at this precision",
+    type=click.Choice(["full", "autocast"]),
+    default="autocast",
+)
 @click.pass_context
 def imagine_cmd(
     ctx,
@@ -174,6 +180,7 @@ def imagine_cmd(
     mask_mode,
     mask_expansion,
     caption,
+    precision,
 ):
     """Have the AI generate images. alias:imagine"""
     if ctx.invoked_subcommand is not None:
@@ -220,6 +227,7 @@ def imagine_cmd(
         record_step_images="images" in show_work,
         output_file_extension="png",
         print_caption=caption,
+        precision=precision,
     )
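
With the new option, full precision can be requested from the command line instead of the default autocast; a hypothetical invocation (prompt text is illustrative):

    imagine --precision full "portrait of a colorful parrot"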

@@ -2,7 +2,7 @@ import importlib
 import logging
 import os.path
 import platform
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import lru_cache
 from typing import List, Optional

@@ -10,7 +10,7 @@ import numpy as np
 import requests
 import torch
 from PIL import Image, ImageFilter
-from torch import Tensor
+from torch import Tensor, autocast
 from torch.nn import functional
 from torch.overrides import handle_torch_function, has_torch_function_variadic
 from transformers import cached_path
@@ -104,6 +104,43 @@ def fix_torch_nn_layer_norm():
         functional.layer_norm = orig_function


+@contextmanager
+def fix_torch_group_norm():
+    """
+    Patch group_norm to cast the weights to the same type as the inputs.
+
+    From what I can understand, all the other repos just switch to full precision
+    instead of addressing this. I think this would make things slower, but I'm not sure.
+
+    https://github.com/pytorch/pytorch/pull/81852
+    """
+    orig_group_norm = functional.group_norm
+
+    def _group_norm_wrapper(
+        input: Tensor,
+        num_groups: int,
+        weight: Optional[Tensor] = None,
+        bias: Optional[Tensor] = None,
+        eps: float = 1e-5,
+    ) -> Tensor:
+        if weight is not None and weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
+        if bias is not None and bias.dtype != input.dtype:
+            bias = bias.to(input.dtype)
+
+        return orig_group_norm(
+            input=input, num_groups=num_groups, weight=weight, bias=bias, eps=eps
+        )
+
+    functional.group_norm = _group_norm_wrapper
+    try:
+        yield
+    finally:
+        functional.group_norm = orig_group_norm
+
+
 def expand_mask(mask_image, size):
     if size < 0:
         threshold = 0.95
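
For reference, a minimal sketch of how the new context manager is meant to be used, mirroring the imagine() change above (the device string and the no-op body are illustrative):

    import torch
    from torch import autocast
    from imaginairy.utils import fix_torch_group_norm, fix_torch_nn_layer_norm

    # Both monkeypatches are active only inside the with-block; the original
    # torch functions are restored on exit, even if an exception is raised.
    with torch.no_grad(), autocast("cuda"), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        ...  # run the sampling loop here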
