mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
316114e660
Wrote an openai script and custom prompt to generate them.
82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
"""Classes for image denoising operations"""
|
|
|
|
from typing import TYPE_CHECKING, Dict, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from imaginairy.utils import instantiate_from_config
|
|
from imaginairy.vendored.k_diffusion.utils import append_dims
|
|
|
|
if TYPE_CHECKING:
|
|
from .denoiser_scaling import DenoiserScaling
|
|
from .discretizer import Discretization
|
|
|
|
|
|
class Denoiser(nn.Module):
    """Apply sigma-dependent scaling around a wrapped network call.

    The scaling object built from ``scaling_config`` is called with the noise
    level and must return the four coefficients
    ``(c_skip, c_out, c_in, c_noise)`` that ``forward`` combines with the
    network input and output.
    """

    def __init__(self, scaling_config: Dict):
        """Build the scaling callable from its config.

        Args:
            scaling_config: Config handed to ``instantiate_from_config``.
        """
        super().__init__()

        scaling = instantiate_from_config(scaling_config)
        self.scaling: DenoiserScaling = scaling

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Hook for subclasses; this base version returns ``sigma`` untouched."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Hook for subclasses; this base version returns ``c_noise`` untouched."""
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input_tensor: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """Run ``network`` on the scaled input and blend with a skip connection.

        Args:
            network: Callable invoked as
                ``network(scaled_input, c_noise, cond, **additional_model_inputs)``.
            input_tensor: Tensor to denoise.
            sigma: Noise level tensor.
            cond: Conditioning passed through to ``network`` unchanged.
            **additional_model_inputs: Extra keyword arguments forwarded to
                ``network``.

        Returns:
            ``network_output * c_out + input_tensor * c_skip``.
        """
        quantized_sigma = self.possibly_quantize_sigma(sigma)

        # Remember the pre-expansion shape so c_noise can be restored to it.
        original_shape = quantized_sigma.shape

        # NOTE(review): append_dims presumably pads trailing singleton dims up
        # to input_tensor.ndim so the coefficients broadcast -- confirm in
        # imaginairy.vendored.k_diffusion.utils.
        expanded_sigma = append_dims(quantized_sigma, input_tensor.ndim)

        c_skip, c_out, c_in, c_noise = self.scaling(expanded_sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(original_shape))

        network_output = network(
            input_tensor * c_in, c_noise, cond, **additional_model_inputs
        )
        return network_output * c_out + input_tensor * c_skip
|
|
|
|
|
|
class DiscreteDenoiser(Denoiser):
    """Denoiser that snaps noise levels onto a fixed table of sigmas.

    The table is produced once, at construction time, by the discretization
    object built from ``discretization_config`` and is stored as the
    ``sigmas`` buffer.
    """

    def __init__(
        self,
        scaling_config: Dict,
        num_idx: int,
        discretization_config: Dict,
        do_append_zero: bool = False,
        quantize_c_noise: bool = True,
        flip: bool = True,
    ):
        """Build the scaling and discretization objects and the sigma table.

        Args:
            scaling_config: Config for the scaling callable (see ``Denoiser``).
            num_idx: Passed to the discretization as its first argument and
                stored on the instance as ``num_idx``.
            discretization_config: Config handed to
                ``instantiate_from_config`` to build the discretization.
            do_append_zero: Forwarded to the discretization call.
            quantize_c_noise: When true, ``possibly_quantize_c_noise`` maps
                ``c_noise`` to table indices instead of returning it as is.
            flip: Forwarded to the discretization call.
        """
        super().__init__(scaling_config)
        self.discretization: Discretization = instantiate_from_config(
            discretization_config
        )
        sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
        # Registered as a buffer (not a plain attribute) so the table is
        # tracked by the module alongside its other tensors.
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise
        self.num_idx = num_idx

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return, for each element of ``sigma``, the index of the nearest table entry.

        Args:
            sigma: Tensor of noise levels of any shape (including 0-d).

        Returns:
            Integer index tensor with the same shape as ``sigma``.
        """
        # Flatten to a single row so the (num_sigmas, 1) table column
        # broadcasts against every element whatever sigma's rank is. Without
        # this, inputs with more than one dimension would broadcast along the
        # wrong axes and either fail in ``view`` or yield wrong indices.
        dists = sigma.reshape(1, -1) - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        """Look up the table entries at ``idx``."""
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Replace every sigma with its nearest entry in the table."""
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Map ``c_noise`` to table indices when ``quantize_c_noise`` is set.

        Returns:
            The index tensor from ``sigma_to_idx`` when quantization is
            enabled, otherwise ``c_noise`` unchanged.
        """
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        return c_noise
|