fix: if weights are float32 but float16 was specified, still use float16

This commit is contained in:
Bryce 2024-01-20 07:57:47 -08:00 committed by Bryce Drennan
parent cf8a44b317
commit e6a1c988c5

View File

@ -603,8 +603,8 @@ def load_sdxl_pipeline_from_diffusers_weights(
text_encoder.load_state_dict(text_encoder_weights, assign=True) text_encoder.load_state_dict(text_encoder_weights, assign=True)
del text_encoder_weights del text_encoder_weights
lda = lda.to(device=device, dtype=torch.float32) lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device) unet = unet.to(device=device, dtype=dtype)
text_encoder = text_encoder.to(device=device) text_encoder = text_encoder.to(device=device, dtype=dtype)
if for_inpainting: if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting StableDiffusionCls = StableDiffusion_XL_Inpainting
else: else: