|
|
|
@ -205,7 +205,6 @@ def imagine(
|
|
|
|
|
fix_torch_group_norm,
|
|
|
|
|
fix_torch_nn_layer_norm,
|
|
|
|
|
get_device,
|
|
|
|
|
platform_appropriate_autocast,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
check_torch_version()
|
|
|
|
@ -227,9 +226,7 @@ def imagine(
|
|
|
|
|
if half_mode is None:
|
|
|
|
|
half_mode = "cuda" in get_device() or get_device() == "mps"
|
|
|
|
|
|
|
|
|
|
with torch.no_grad(), platform_appropriate_autocast(
|
|
|
|
|
precision
|
|
|
|
|
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
|
|
|
|
|
with torch.no_grad(), fix_torch_nn_layer_norm(), fix_torch_group_norm():
|
|
|
|
|
for i, prompt in enumerate(prompts):
|
|
|
|
|
concrete_prompt = prompt.make_concrete_copy()
|
|
|
|
|
prog_text = f"{i + 1}/{num_prompts}"
|
|
|
|
|