From 340a90bacd684cd5fd6c277eccedc0da707b8b81 Mon Sep 17 00:00:00 2001 From: Bryce Date: Wed, 5 Oct 2022 21:50:20 -0700 Subject: [PATCH] feature: allow loading custom model weights at command line Addresses #40 --- imaginairy/api.py | 30 ++++++++++++++++++++++++------ imaginairy/cmds.py | 11 +++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/imaginairy/api.py b/imaginairy/api.py index b323f01..5e413ca 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -50,10 +50,23 @@ class SafetyMode: # the press or governments to freak out about AI... IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER) +DEFAULT_MODEL_WEIGHTS_LOCATION = ( + "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media" +) + -def load_model_from_config(config): - url = "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media" - ckpt_path = cached_path(url) +def load_model_from_config( + config, model_weights_location=DEFAULT_MODEL_WEIGHTS_LOCATION +): + model_weights_location = ( + model_weights_location + if model_weights_location + else DEFAULT_MODEL_WEIGHTS_LOCATION + ) + if model_weights_location.startswith("http"): + ckpt_path = cached_path(model_weights_location) + else: + ckpt_path = model_weights_location logger.info(f"Loading model onto {get_device()} backend...") logger.debug(f"Loading model from {ckpt_path}") pl_sd = torch.load(ckpt_path, map_location="cpu") @@ -73,10 +86,12 @@ def load_model_from_config(config): @lru_cache() -def load_model(): +def load_model(model_weights_location=None): config = "configs/stable-diffusion-v1.yaml" config = OmegaConf.load(f"{LIB_PATH}/{config}") - model = load_model_from_config(config) + model = load_model_from_config( + config, model_weights_location=model_weights_location + ) model = model.to(get_device()) return model @@ -91,6 +106,7 @@ def imagine_image_files( record_step_images=False, output_file_extension="jpg", print_caption=False, + model_weights_path=None, ): generated_imgs_path = os.path.join(outdir, "generated") os.makedirs(generated_imgs_path, exist_ok=True) @@ -118,6 +134,7 @@ def imagine_image_files( ddim_eta=ddim_eta, img_callback=_record_step if record_step_images else None, add_caption=print_caption, + model_weights_path=model_weights_path, ): prompt = result.prompt img_str = "" @@ -146,8 +163,9 @@ def imagine( img_callback=None, half_mode=None, add_caption=False, + model_weights_path=None, ): - model = load_model() + model = load_model(model_weights_location=model_weights_path) # only run half-mode on cuda. run it by default half_mode = half_mode is None and get_device() == "cuda" diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py index c5291e2..71d1f0e 100644 --- a/imaginairy/cmds.py +++ b/imaginairy/cmds.py @@ -4,7 +4,7 @@ import click from click_shell import shell from imaginairy import LazyLoadingImage, generate_caption -from imaginairy.api import imagine_image_files, load_model +from imaginairy.api import imagine_image_files from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS from imaginairy.schema import ImaginePrompt from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings @@ -173,6 +173,12 @@ def configure_logging(level="INFO"): type=click.Choice(["full", "autocast"]), default="autocast", ) +@click.option( + "--model-weights-path", + help="path to model weights file. by default we use stable diffusion 1.4", + type=click.Path(exists=True), + default=None, +) @click.pass_context def imagine_cmd( ctx, @@ -201,6 +207,7 @@ def imagine_cmd( mask_modify_original, caption, precision, + model_weights_path, ): """Have the AI generate images. alias:imagine""" if ctx.invoked_subcommand is not None: @@ -223,7 +230,6 @@ def imagine_cmd( if fix_faces_fidelity is not None: fix_faces_fidelity = float(fix_faces_fidelity) prompts = [] - load_model() for _ in range(repeats): for prompt_text in prompt_texts: prompt = ImaginePrompt( @@ -255,6 +261,7 @@ def imagine_cmd( output_file_extension="jpg", print_caption=caption, precision=precision, + model_weights_path=model_weights_path, )