From d3106fc9e359a417e2ad404d1de6f40491b6e76e Mon Sep 17 00:00:00 2001 From: Bryce Drennan Date: Fri, 5 Jan 2024 06:34:17 -0800 Subject: [PATCH] fix: videogen bug (#443) --- imaginairy/api/generate.py | 5 +---- imaginairy/api/video_sample.py | 2 +- imaginairy/utils/__init__.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py index 472ee45..3806b34 100755 --- a/imaginairy/api/generate.py +++ b/imaginairy/api/generate.py @@ -205,7 +205,6 @@ def imagine( fix_torch_group_norm, fix_torch_nn_layer_norm, get_device, - platform_appropriate_autocast, ) check_torch_version() @@ -227,9 +226,7 @@ def imagine( if half_mode is None: half_mode = "cuda" in get_device() or get_device() == "mps" - with torch.no_grad(), platform_appropriate_autocast( - precision - ), fix_torch_nn_layer_norm(), fix_torch_group_norm(): + with torch.no_grad(), fix_torch_nn_layer_norm(), fix_torch_group_norm(): for i, prompt in enumerate(prompts): concrete_prompt = prompt.make_concrete_copy() prog_text = f"{i + 1}/{num_prompts}" diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py index 6a63c0f..8c63e8f 100644 --- a/imaginairy/api/video_sample.py +++ b/imaginairy/api/video_sample.py @@ -241,7 +241,7 @@ def generate_video( c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) - randn = torch.randn(shape, device=device) + randn = torch.randn(shape, device=device, dtype=torch.float16) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index a41cb81..65ed547 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -91,7 +91,7 @@ def platform_appropriate_autocast(precision="autocast", enabled=True): # https://github.com/pytorch/pytorch/issues/55374 # https://github.com/invoke-ai/InvokeAI/pull/518 - if precision == "autocast" and get_device() in ("cuda",) and False: + if precision == "autocast" and get_device() in ("cuda",): with autocast(get_device(), enabled=enabled): yield else: