Seems to be caused by incompatible types in group_norm when we use autocast.
Patch group_norm to cast the weights to the same type as the inputs
From what I can understand all the other repos just switch to full precision instead
of addressing this. I think this would make things slower but I'm not sure. So maybe
the patching solution I'm doing is better?
https://github.com/pytorch/pytorch/pull/81852