fix: if weights are float32 but float16 was specified, still use float16

bd/refactor-upscalers
Bryce 4 months ago
parent 646b6e9a62
commit 9f491369a1

@ -603,8 +603,8 @@ def load_sdxl_pipeline_from_diffusers_weights(
text_encoder.load_state_dict(text_encoder_weights, assign=True) text_encoder.load_state_dict(text_encoder_weights, assign=True)
del text_encoder_weights del text_encoder_weights
lda = lda.to(device=device, dtype=torch.float32) lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device) unet = unet.to(device=device, dtype=dtype)
text_encoder = text_encoder.to(device=device) text_encoder = text_encoder.to(device=device, dtype=dtype)
if for_inpainting: if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting StableDiffusionCls = StableDiffusion_XL_Inpainting
else: else:

Loading…
Cancel
Save