fix: if weights are float32 but float16 was specified, still use float16

pull/474/head
Authored by Bryce 4 months ago, committed by Bryce Drennan
parent cf8a44b317
commit e6a1c988c5

@@ -603,8 +603,8 @@ def load_sdxl_pipeline_from_diffusers_weights(
     text_encoder.load_state_dict(text_encoder_weights, assign=True)
     del text_encoder_weights
     lda = lda.to(device=device, dtype=torch.float32)
-    unet = unet.to(device=device)
-    text_encoder = text_encoder.to(device=device)
+    unet = unet.to(device=device, dtype=dtype)
+    text_encoder = text_encoder.to(device=device, dtype=dtype)
     if for_inpainting:
         StableDiffusionCls = StableDiffusion_XL_Inpainting
     else:
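
Context for the change: calling `.to(device=device)` without a dtype only moves a module, leaving its parameters in whatever precision the checkpoint was saved in (here, float32), whereas passing `dtype=dtype` also converts the weights to the requested precision (e.g. float16). The sketch below illustrates this standard PyTorch behavior with a hypothetical `torch.nn.Linear` module standing in for the SDXL UNet / text encoder; it is not part of the commit itself.

```python
import torch

# Hypothetical stand-in for the UNet / text encoder; parameters default to float32.
layer = torch.nn.Linear(4, 4)

requested_dtype = torch.float16  # e.g. what the caller asked for via `dtype`

# .to(device=...) alone only moves the module; float32 weights stay float32.
moved = layer.to(device="cpu")
assert moved.weight.dtype == torch.float32

# Passing dtype explicitly converts the weights to the requested precision,
# which is what the fixed lines now do for the UNet and text encoder.
converted = layer.to(device="cpu", dtype=requested_dtype)
assert converted.weight.dtype == torch.float16
```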
