fix: try to address #13 BFloat16 issue

Seems to be caused by mismatched dtypes in group_norm when we use autocast:
the activations get cast to bfloat16 while the module weights stay float32.

Patch group_norm to cast the weights to the same dtype as the inputs.

From what I can tell, all the other repos just switch to full precision
instead of addressing this. I think that would make things slower, but I'm
not sure. So maybe the patching solution here is better?

https://github.com/pytorch/pytorch/pull/81852
Bryce · commit 09bc1c70e6 · parent e23e363bf5 · pull/18/head
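
For reference, a minimal sketch of the failure mode being patched. The shapes
are hypothetical, and it constructs the dtype mismatch directly rather than
going through autocast:

# Hypothetical repro: bfloat16 activations meet float32 norm parameters.
import torch
from torch.nn import functional as F

x = torch.randn(2, 32, 8, 8).to(torch.bfloat16)  # activations as autocast leaves them
weight = torch.ones(32)  # module parameters stay float32
bias = torch.zeros(32)

# Without mixed-dtype support in group_norm (the PyTorch PR linked above),
# this raises a RuntimeError; the patch casts weight/bias to x.dtype first.
out = F.group_norm(x, num_groups=8, weight=weight, bias=bias)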

@@ -51,6 +51,12 @@ deploy: ## Deploy the package to pypi.org
 	rm -rf dist
 	@echo "Deploy successful! ✨ 🍰 ✨"
 
+build-dev-image:
+	docker build -f tests/Dockerfile -t imaginairy-dev .
+
+run-dev: build-dev-image
+	docker run -it -v $$HOME/.cache/huggingface:/root/.cache/huggingface -v $$HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy-dev /bin/bash
+
 requirements: ## Freeze the requirements.txt file
 	pip-compile setup.py requirements-dev.in --output-file=requirements-dev.txt --upgrade
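
Usage note for the new targets (per the hunk above): run-dev builds the image
first, then drops into a shell with the host model caches mounted, so weights
aren't re-downloaded on each run.

make run-dev  # build imaginairy-dev, then open a bash shell in the container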

@@ -29,6 +29,7 @@ from imaginairy.samplers.base import get_sampler
 from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import (
     expand_mask,
+    fix_torch_group_norm,
     fix_torch_nn_layer_norm,
     get_device,
     instantiate_from_config,
@@ -156,13 +157,16 @@ def imagine(
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None
     if get_device() == "cpu":
         logger.info("Running in CPU mode. it's gonna be slooooooow.")
     precision_scope = (
         autocast
         if precision == "autocast" and get_device() in ("cuda", "cpu")
         else nullcontext
     )
-    with torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm():
+    with torch.no_grad(), precision_scope(
+        get_device()
+    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
         for prompt in prompts:
             with ImageLoggingContext(
                 prompt=prompt,

@@ -149,6 +149,12 @@ def configure_logging(level="INFO"):
     is_flag=True,
     help="Generate a text description of the generated image",
 )
+@click.option(
+    "--precision",
+    help="Evaluate at this precision",
+    type=click.Choice(["full", "autocast"]),
+    default="autocast",
+)
 @click.pass_context
 def imagine_cmd(
     ctx,
@@ -174,6 +180,7 @@ def imagine_cmd(
     mask_mode,
     mask_expansion,
     caption,
+    precision,
 ):
     """Have the AI generate images. alias:imagine"""
     if ctx.invoked_subcommand is not None:
@@ -220,6 +227,7 @@
         record_step_images="images" in show_work,
         output_file_extension="png",
         print_caption=caption,
+        precision=precision,
     )
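
Not part of the diff, but for context, the new flag would be exercised like
this (command alias and option names per the declarations above):

imagine "a forest at dawn" --precision full  # force full precision if autocast misbehaves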

@@ -2,7 +2,7 @@ import importlib
 import logging
 import os.path
 import platform
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import lru_cache
 from typing import List, Optional
@@ -10,7 +10,7 @@ import numpy as np
 import requests
 import torch
 from PIL import Image, ImageFilter
-from torch import Tensor
+from torch import Tensor, autocast
 from torch.nn import functional
 from torch.overrides import handle_torch_function, has_torch_function_variadic
 from transformers import cached_path
@@ -104,6 +104,43 @@ def fix_torch_nn_layer_norm():
         functional.layer_norm = orig_function
 
 
+@contextmanager
+def fix_torch_group_norm():
+    """
+    Patch group_norm to cast the weights to the same dtype as the inputs.
+
+    From what I can tell, all the other repos just switch to full precision
+    instead of addressing this. I think that would make things slower, but
+    I'm not sure.
+
+    https://github.com/pytorch/pytorch/pull/81852
+    """
+    orig_group_norm = functional.group_norm
+
+    def _group_norm_wrapper(
+        input: Tensor,
+        num_groups: int,
+        weight: Optional[Tensor] = None,
+        bias: Optional[Tensor] = None,
+        eps: float = 1e-5,
+    ) -> Tensor:
+        if weight is not None and weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
+        if bias is not None and bias.dtype != input.dtype:
+            bias = bias.to(input.dtype)
+        return orig_group_norm(
+            input=input, num_groups=num_groups, weight=weight, bias=bias, eps=eps
+        )
+
+    functional.group_norm = _group_norm_wrapper
+    try:
+        yield
+    finally:
+        functional.group_norm = orig_group_norm
+
+
 def expand_mask(mask_image, size):
     if size < 0:
         threshold = 0.95
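
Not in the diff, but a sketch of how the context manager is meant to be used
(run_inference is a stand-in, not a real function in this repo):

# The monkey patch is only active inside the block; the stock
# functional.group_norm is restored on exit, even if inference raises.
with fix_torch_group_norm():
    result = run_inference()  # stand-in for the model call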
