fix: if weights are float32 but float16 was specified, still use float16

pull/461/head
Bryce 5 months ago
parent ba7721f5ec
commit b5874a64cd

@@ -603,8 +603,8 @@ def load_sdxl_pipeline_from_diffusers_weights(
     text_encoder.load_state_dict(text_encoder_weights, assign=True)
     del text_encoder_weights
     lda = lda.to(device=device, dtype=torch.float32)
-    unet = unet.to(device=device)
-    text_encoder = text_encoder.to(device=device)
+    unet = unet.to(device=device, dtype=dtype)
+    text_encoder = text_encoder.to(device=device, dtype=dtype)
     if for_inpainting:
         StableDiffusionCls = StableDiffusion_XL_Inpainting
     else:
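
For context beyond the diff, here is a minimal sketch of the casting pattern this commit settles on, assuming plain PyTorch modules (the function name cast_pipeline and its argument names are hypothetical, not part of imaginairy's API): weights may be loaded from a float32 checkpoint, but the caller's requested dtype (e.g. torch.float16) should still be applied, so each submodule is cast explicitly rather than left at the checkpoint's precision.

    import torch
    from torch import nn

    def cast_pipeline(
        unet: nn.Module,
        text_encoder: nn.Module,
        lda: nn.Module,
        device: torch.device,
        dtype: torch.dtype,
    ):
        # Keep the autoencoder (lda) in float32, mirroring the diff above,
        # while the UNet and text encoder follow the requested dtype even
        # when the source weights were stored as float32.
        lda = lda.to(device=device, dtype=torch.float32)
        unet = unet.to(device=device, dtype=dtype)
        text_encoder = text_encoder.to(device=device, dtype=dtype)
        return unet, text_encoder, lda

For example, cast_pipeline(unet, text_encoder, lda, torch.device("cuda"), torch.float16) would yield float16 UNet and text-encoder weights even if the checkpoint on disk stored them as float32.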
