@ -2,7 +2,7 @@ import importlib
import logging
import logging
import os . path
import os . path
import platform
import platform
from contextlib import contextmanager
from contextlib import contextmanager , nullcontext
from functools import lru_cache
from functools import lru_cache
from typing import List , Optional
from typing import List , Optional
@ -10,7 +10,7 @@ import numpy as np
import requests
import requests
import torch
import torch
from PIL import Image , ImageFilter
from PIL import Image , ImageFilter
from torch import Tensor
from torch import Tensor , autocast
from torch . nn import functional
from torch . nn import functional
from torch . overrides import handle_torch_function , has_torch_function_variadic
from torch . overrides import handle_torch_function , has_torch_function_variadic
from transformers import cached_path
from transformers import cached_path
@ -104,6 +104,43 @@ def fix_torch_nn_layer_norm():
functional . layer_norm = orig_function
functional . layer_norm = orig_function
@contextmanager
def fix_torch_group_norm ( ) :
"""
Patch group_norm to cast the weights to the same type as the inputs
From what I can understand all the other repos just switch to full precision instead
of addressing this . I think this would make things slower but I ' m not sure.
https : / / github . com / pytorch / pytorch / pull / 81852
"""
orig_group_norm = functional . group_norm
def _group_norm_wrapper (
input : Tensor ,
num_groups : int ,
weight : Optional [ Tensor ] = None ,
bias : Optional [ Tensor ] = None ,
eps : float = 1e-5 ,
) - > Tensor :
if weight is not None and weight . dtype != input . dtype :
weight = weight . to ( input . dtype )
if bias is not None and bias . dtype != input . dtype :
bias = bias . to ( input . dtype )
return orig_group_norm (
input = input , num_groups = num_groups , weight = weight , bias = bias , eps = eps
)
functional . group_norm = _group_norm_wrapper
try :
yield
finally :
functional . group_norm = orig_group_norm
def expand_mask ( mask_image , size ) :
def expand_mask ( mask_image , size ) :
if size < 0 :
if size < 0 :
threshold = 0.95
threshold = 0.95