feature: update refiners

Better handles img2img (partial diffusion runs).
This commit is contained in:
Bryce 2024-01-20 08:36:53 -08:00 committed by Bryce Drennan
parent 1bf53e47cf
commit cf8a44b317
13 changed files with 64 additions and 36 deletions

View File

@ -210,7 +210,7 @@ vendorize_normal_map:
vendorize_refiners: vendorize_refiners:
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \ export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=91aea9b7ff63ddf93f99e2ce6a4452bd658b1948 && \
make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \ make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
mkdir -p ./imaginairy/vendored/$$PKG && \ mkdir -p ./imaginairy/vendored/$$PKG && \
rm -rf ./imaginairy/vendored/$$PKG/* && \ rm -rf ./imaginairy/vendored/$$PKG/* && \

View File

@ -289,7 +289,7 @@ def generate_single_image(
msg = f"Unknown solver type: {prompt.solver_type}" msg = f"Unknown solver type: {prompt.solver_type}"
raise ValueError(msg) raise ValueError(msg)
sd.scheduler.to(device=sd.unet.device, dtype=sd.unet.dtype) sd.scheduler.to(device=sd.unet.device, dtype=sd.unet.dtype)
sd.set_num_inference_steps(prompt.steps) sd.set_inference_steps(prompt.steps, first_step=first_step)
if hasattr(sd, "mask_latents") and mask_image is not None: if hasattr(sd, "mask_latents") and mask_image is not None:
sd.set_inpainting_conditions( sd.set_inpainting_conditions(
@ -306,11 +306,11 @@ def generate_single_image(
if init_latent is not None: if init_latent is not None:
noise_step = noise_step if noise_step is not None else first_step noise_step = noise_step if noise_step is not None else first_step
if first_step >= len(sd.steps): if first_step >= len(sd.scheduler.all_steps):
noised_latent = init_latent noised_latent = init_latent
else: else:
noised_latent = sd.scheduler.add_noise( noised_latent = sd.scheduler.add_noise(
x=init_latent, noise=noise, step=sd.steps[noise_step] x=init_latent, noise=noise, step=sd.scheduler.all_steps[noise_step]
) )
with lc.timing("text-conditioning"): with lc.timing("text-conditioning"):
@ -330,7 +330,7 @@ def generate_single_image(
with lc.timing("unet"): with lc.timing("unet"):
for step in tqdm( for step in tqdm(
sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}", leave=False sd.steps, bar_format=" {l_bar}{bar}{r_bar}", leave=False
): ):
log_latent(x, "noisy_latent") log_latent(x, "noisy_latent")
x = sd( x = sd(

View File

@ -193,19 +193,23 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}", f"device={tensor.device}",
] ]
if not tensor.is_complex(): if tensor.is_complex():
info_list.extend( tensor_f = tensor.real.float()
[ else:
f"min={tensor.min():.2f}", # type: ignore if tensor.numel() > 0:
f"max={tensor.max():.2f}", # type: ignore info_list.extend(
] [
) f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
]
)
tensor_f = tensor.float()
info_list.extend( info_list.extend(
[ [
f"mean={tensor.float().mean():.2f}", f"mean={tensor_f.mean():.2f}",
f"std={tensor.float().std():.2f}", f"std={tensor_f.std():.2f}",
f"norm={norm(x=tensor.float()):.2f}", f"norm={norm(x=tensor_f):.2f}",
f"grad={tensor.requires_grad}", f"grad={tensor.requires_grad}",
] ]
) )

View File

@ -32,21 +32,21 @@ class LatentDiffusionModel(fl.Module, ABC):
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype) self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype) self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int) -> None: def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__( self.scheduler = self.scheduler.__class__(
num_inference_steps, num_inference_steps=num_steps,
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_step,
).to(device=device, dtype=dtype) ).to(device=device, dtype=dtype)
def init_latents( def init_latents(
self, self,
size: tuple[int, int], size: tuple[int, int],
init_image: Image.Image | None = None, init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None, noise: Tensor | None = None,
) -> Tensor: ) -> Tensor:
height, width = size height, width = size
@ -59,11 +59,15 @@ class LatentDiffusionModel(fl.Module, ABC):
if init_image is None: if init_image is None:
return noise return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height))) encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step]) return self.scheduler.add_noise(
x=encoded_image,
noise=noise,
step=self.scheduler.first_inference_step,
)
@property @property
def steps(self) -> list[int]: def steps(self) -> list[int]:
return self.scheduler.steps return self.scheduler.inference_steps
@abstractmethod @abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:

View File

@ -24,23 +24,23 @@ def compute_sinusoidal_embedding(
class RangeEncoder(fl.Chain): class RangeEncoder(fl.Chain):
def __init__( def __init__(
self, self,
sinuosidal_embedding_dim: int, sinusoidal_embedding_dim: int,
embedding_dim: int, embedding_dim: int,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
self.sinuosidal_embedding_dim = sinuosidal_embedding_dim self.sinusoidal_embedding_dim = sinusoidal_embedding_dim
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
super().__init__( super().__init__(
fl.Lambda(self.compute_sinuosoidal_embedding), fl.Lambda(self.compute_sinusoidal_embedding),
fl.Converter(set_device=False, set_dtype=True), fl.Converter(set_device=False, set_dtype=True),
fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.Linear(in_features=sinusoidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
fl.SiLU(), fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
) )
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]: def compute_sinusoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim) return compute_sinusoidal_embedding(x, embedding_dim=self.sinusoidal_embedding_dim)
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]): class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):

View File

@ -11,6 +11,7 @@ class DDIM(Scheduler):
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
) -> None: ) -> None:
@ -20,6 +21,7 @@ class DDIM(Scheduler):
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule, noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -35,6 +37,8 @@ class DDIM(Scheduler):
return timesteps.flip(0) return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
timestep, previous_timestep = ( timestep, previous_timestep = (
self.timesteps[step], self.timesteps[step],
( (

View File

@ -5,8 +5,9 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche
class DDPM(Scheduler): class DDPM(Scheduler):
""" """
The Denoising Diffusion Probabilistic Models (DDPM) is a specific type of diffusion model, Denoising Diffusion Probabilistic Model
which uses a specific strategy to generate the timesteps and applies the diffusion process in a specific way.
Only used for training Latent Diffusion models. Cannot be called.
""" """
def __init__( def __init__(
@ -15,6 +16,7 @@ class DDPM(Scheduler):
num_train_timesteps: int = 1_000, num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
) -> None: ) -> None:
super().__init__( super().__init__(
@ -22,6 +24,7 @@ class DDPM(Scheduler):
num_train_timesteps=num_train_timesteps, num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_inference_step,
device=device, device=device,
) )

View File

@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False, last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
): ):
@ -33,12 +34,12 @@ class DPMSolver(Scheduler):
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule, noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.estimated_data = deque([tensor([])] * 2, maxlen=2) self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order self.last_step_first_order = last_step_first_order
self._first_step_has_been_run = False
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because: # We need to use numpy here because:
@ -81,6 +82,7 @@ class DPMSolver(Scheduler):
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std = self.noise_std[previous_timestep] previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep] current_noise_std = self.noise_std[current_timestep]
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - next_ratio) / (previous_ratio - current_ratio)
) )
@ -100,13 +102,14 @@ class DPMSolver(Scheduler):
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs). (ODEs).
""" """
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data) self.estimated_data.append(estimated_denoised_data)
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run: if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
self._first_step_has_been_run = True
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step) return self.multistep_dpm_solver_second_order_update(x=x, step=step)

View File

@ -13,6 +13,7 @@ class EulerScheduler(Scheduler):
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
): ):
@ -24,6 +25,7 @@ class EulerScheduler(Scheduler):
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule, noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -64,6 +66,8 @@ class EulerScheduler(Scheduler):
s_tmax: float = float("inf"), s_tmax: float = float("inf"),
s_noise: float = 1.0, s_noise: float = 1.0,
) -> Tensor: ) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
sigma = self.sigmas[step] sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

View File

@ -33,6 +33,7 @@ class Scheduler(ABC):
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = float32, dtype: DType = float32,
): ):
@ -43,6 +44,7 @@ class Scheduler(ABC):
self.initial_diffusion_rate = initial_diffusion_rate self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule self.noise_schedule = noise_schedule
self.first_inference_step = first_inference_step
self.scale_factors = self.sample_noise_schedule() self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0)) self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
@ -68,9 +70,13 @@ class Scheduler(ABC):
... ...
@property @property
def steps(self) -> list[int]: def all_steps(self) -> list[int]:
return list(range(self.num_inference_steps)) return list(range(self.num_inference_steps))
@property
def inference_steps(self) -> list[int]:
return self.all_steps[self.first_inference_step :]
def scale_model_input(self, x: Tensor, step: int) -> Tensor: def scale_model_input(self, x: Tensor, step: int) -> Tensor:
""" """
For compatibility with schedulers that need to scale the input according to the current timestep. For compatibility with schedulers that need to scale the input according to the current timestep.

View File

@ -28,7 +28,7 @@ class TextTimeEmbedding(fl.Chain):
fl.Chain( fl.Chain(
fl.UseContext(context="diffusion", key="time_ids"), fl.UseContext(context="diffusion", key="time_ids"),
fl.Unsqueeze(dim=-1), fl.Unsqueeze(dim=-1),
fl.Lambda(func=self.compute_sinuosoidal_embedding), fl.Lambda(func=self.compute_sinusoidal_embedding),
fl.Reshape(-1), fl.Reshape(-1),
), ),
dim=1, dim=1,
@ -49,7 +49,7 @@ class TextTimeEmbedding(fl.Chain):
), ),
) )
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor: def compute_sinusoidal_embedding(self, x: Tensor) -> Tensor:
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim) return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim)
@ -61,7 +61,7 @@ class TimestepEncoder(fl.Passthrough):
fl.Chain( fl.Chain(
fl.UseContext(context="diffusion", key="timestep"), fl.UseContext(context="diffusion", key="timestep"),
RangeEncoder( RangeEncoder(
sinuosidal_embedding_dim=320, sinusoidal_embedding_dim=320,
embedding_dim=self.timestep_embedding_dim, embedding_dim=self.timestep_embedding_dim,
device=device, device=device,
dtype=dtype, dtype=dtype,

View File

@ -1 +1 @@
vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9 vendored from git@github.com:finegrain-ai/refiners.git @ 91aea9b7ff63ddf93f99e2ce6a4452bd658b1948

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 MiB

After

Width:  |  Height:  |  Size: 3.0 MiB