fix: videogen bug (#443)

pull/445/head
Bryce Drennan 4 months ago committed by GitHub
parent 89bc1a9f1c
commit d3106fc9e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -205,7 +205,6 @@ def imagine(
fix_torch_group_norm,
fix_torch_nn_layer_norm,
get_device,
platform_appropriate_autocast,
)
check_torch_version()
@ -227,9 +226,7 @@ def imagine(
if half_mode is None:
half_mode = "cuda" in get_device() or get_device() == "mps"
with torch.no_grad(), platform_appropriate_autocast(
precision
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
with torch.no_grad(), fix_torch_nn_layer_norm(), fix_torch_group_norm():
for i, prompt in enumerate(prompts):
concrete_prompt = prompt.make_concrete_copy()
prog_text = f"{i + 1}/{num_prompts}"

@ -241,7 +241,7 @@ def generate_video(
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
randn = torch.randn(shape, device=device)
randn = torch.randn(shape, device=device, dtype=torch.float16)
additional_model_inputs = {}
additional_model_inputs["image_only_indicator"] = torch.zeros(

@ -91,7 +91,7 @@ def platform_appropriate_autocast(precision="autocast", enabled=True):
# https://github.com/pytorch/pytorch/issues/55374
# https://github.com/invoke-ai/InvokeAI/pull/518
if precision == "autocast" and get_device() in ("cuda",) and False:
if precision == "autocast" and get_device() in ("cuda",):
with autocast(get_device(), enabled=enabled):
yield
else:

Loading…
Cancel
Save