fix: performance improvement. disable ema (#139)

The `use_ema: False` configuration setting became necessary in the newer Stable Diffusion code but was missing from the 1.5 config.
Bryce Drennan 2022-12-18 00:00:38 -08:00 committed by GitHub
parent ad5e467042
commit ccf9749df5
6 changed files with 27 additions and 13 deletions
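For background, `use_ema` controls whether the diffusion model also maintains an exponential-moving-average (EMA) copy of its weights. The EMA copy is useful during training but unnecessary at inference time, so disabling it avoids allocating and loading a second full set of weights. A minimal sketch of how such a flag is typically consumed by the model constructor (illustrative names; the real code wraps the weights in a dedicated EMA helper rather than a plain copy):

import copy
import torch.nn as nn

class LatentDiffusionSketch(nn.Module):
    # Hedged sketch: how a `use_ema` config value gates the EMA weight copy.
    def __init__(self, unet: nn.Module, use_ema: bool = True):
        super().__init__()
        self.model = unet
        self.use_ema = use_ema
        if self.use_ema:
            # Only allocated when the flag is on, so an inference-only config
            # can set `use_ema: False` and skip this extra parameter copy.
            self.model_ema = copy.deepcopy(self.model)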

.gitignore

@@ -21,4 +21,6 @@ gfpgan/**
tests/vastai_cli.py
/tests/test_output_local_cuda/
/testing_support/
.unison*
.unison*
*.kgrind
*.pyprof

@@ -298,3 +298,9 @@ aimg.add_command(imagine_cmd, name="imagine")
if __name__ == "__main__":
imagine_cmd() # noqa
# from cProfile import Profile
# from pyprof2calltree import convert, visualize
# profiler = Profile()
# profiler.runctx("imagine_cmd.main(standalone_mode=False)", locals(), globals())
# convert(profiler.getstats(), 'imagine.kgrind')
# visualize(profiler.getstats())
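The commented-out lines above are a developer hook for profiling the CLI: cProfile collects the stats and pyprof2calltree converts them into a calltree file that KCachegrind can open, which is also why `*.kgrind` and `*.pyprof` are added to `.gitignore` in this commit. A self-contained sketch of the same pattern, with a stand-in workload in place of `imagine_cmd.main(standalone_mode=False)`:

from cProfile import Profile

from pyprof2calltree import convert  # pip install pyprof2calltree

def workload():
    # stand-in for the real CLI entry point being profiled
    return sum(i * i for i in range(100_000))

profiler = Profile()
profiler.runctx("workload()", globals(), locals())
# Write the stats in calltree format so they can be inspected in KCachegrind.
convert(profiler.getstats(), "imagine.kgrind")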

@@ -16,6 +16,7 @@ model:
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
use_ema: False
scheduler_config: # 10000 warm-up steps
target: ldm.lr_scheduler.LambdaLinearScheduler

@@ -1,5 +1,5 @@
model:
base_learning_rate: 1.0e-04
base_learning_rate: 1.0e-4
target: imaginairy.modules.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
@@ -11,10 +11,11 @@ model:
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
scheduler_config: # 10000 warm-up steps
target: imaginairy.lr_scheduler.LambdaLinearScheduler
@@ -28,6 +29,7 @@ model:
unet_config:
target: imaginairy.modules.diffusion.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
@@ -39,7 +41,6 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
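Both YAML files follow the ldm-style `target`/`params` convention: `target` is a dotted import path and everything under `params` is passed to that class's constructor, so the new `use_ema: False` entry reaches the model as an ordinary keyword argument. A minimal sketch of that instantiation pattern (OmegaConf is assumed for loading; the helper name and config path are illustrative):

import importlib

from omegaconf import OmegaConf

def instantiate_from_config_sketch(model_config):
    # e.g. "imaginairy.modules.diffusion.ddpm.LatentDiffusion" -> module path + class name
    module_path, cls_name = model_config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    # params (including use_ema: False) become constructor keyword arguments
    return cls(**model_config.get("params", {}))

config = OmegaConf.load("configs/stable-diffusion-v1.yaml")  # path assumed
model = instantiate_from_config_sketch(config.model)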

@@ -80,7 +80,8 @@ def load_model_from_config(config, weights_location):
except FileNotFoundError as e:
if e.errno == 2:
logger.error(
f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.')
f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.'
)
sys.exit(1)
raise e
except RuntimeError as e:

@@ -449,16 +449,19 @@ class SpatialTransformer(nn.Module):
b, c, h, w = x.shape # noqa
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
else:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x + x_in
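Taken together, the hunk above restructures `SpatialTransformer.forward` into two explicit branches on `self.use_linear`, so the convolutional-projection path no longer pays for the `.contiguous()` copies that only the linear-projection path needs. A sketch of the resulting control flow, reconstructed from the lines shown (indentation and exact line order are assumptions, not the verbatim result):

from einops import rearrange

def spatial_transformer_forward(self, x, context=None):
    # Sketch of SpatialTransformer.forward after this change; `self` is assumed
    # to provide norm, proj_in, proj_out, transformer_blocks and use_linear.
    if not isinstance(context, list):
        # simplified stand-in for the class's context-list handling
        context = [context] * len(self.transformer_blocks)
    b, c, h, w = x.shape  # noqa
    x_in = x
    x = self.norm(x)
    if self.use_linear:
        # linear projections work on the flattened (h w) token layout
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    else:
        # convolutional projections stay in b c h w layout and avoid the
        # extra .contiguous() copies
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
    return x + x_in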