feature: allow loading custom model weights at command line

Addresses #40
This commit is contained in:
Bryce 2022-10-05 21:50:20 -07:00 committed by Bryce Drennan
parent 282fbc19b5
commit 340a90bacd
2 changed files with 33 additions and 8 deletions

View File

@@ -50,10 +50,23 @@ class SafetyMode:
 # the press or governments to freak out about AI...
 IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
 
+DEFAULT_MODEL_WEIGHTS_LOCATION = (
+    "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
+)
 
-def load_model_from_config(config):
-    url = "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
-    ckpt_path = cached_path(url)
+
+def load_model_from_config(
+    config, model_weights_location=DEFAULT_MODEL_WEIGHTS_LOCATION
+):
+    model_weights_location = (
+        model_weights_location
+        if model_weights_location
+        else DEFAULT_MODEL_WEIGHTS_LOCATION
+    )
+    if model_weights_location.startswith("http"):
+        ckpt_path = cached_path(model_weights_location)
+    else:
+        ckpt_path = model_weights_location
     logger.info(f"Loading model onto {get_device()} backend...")
     logger.debug(f"Loading model from {ckpt_path}")
     pl_sd = torch.load(ckpt_path, map_location="cpu")
@@ -73,10 +86,12 @@ def load_model_from_config(config):
 @lru_cache()
-def load_model():
+def load_model(model_weights_location=None):
     config = "configs/stable-diffusion-v1.yaml"
     config = OmegaConf.load(f"{LIB_PATH}/{config}")
-    model = load_model_from_config(config)
+    model = load_model_from_config(
+        config, model_weights_location=model_weights_location
+    )
     model = model.to(get_device())
     return model
@@ -91,6 +106,7 @@ def imagine_image_files(
     record_step_images=False,
     output_file_extension="jpg",
     print_caption=False,
+    model_weights_path=None,
 ):
     generated_imgs_path = os.path.join(outdir, "generated")
     os.makedirs(generated_imgs_path, exist_ok=True)
@@ -118,6 +134,7 @@ def imagine_image_files(
         ddim_eta=ddim_eta,
         img_callback=_record_step if record_step_images else None,
         add_caption=print_caption,
+        model_weights_path=model_weights_path,
     ):
         prompt = result.prompt
         img_str = ""
@@ -146,8 +163,9 @@ def imagine(
     img_callback=None,
     half_mode=None,
     add_caption=False,
+    model_weights_path=None,
 ):
-    model = load_model()
+    model = load_model(model_weights_location=model_weights_path)
     # only run half-mode on cuda. run it by default
     half_mode = half_mode is None and get_device() == "cuda"

View File

@@ -4,7 +4,7 @@ import click
 from click_shell import shell
 from imaginairy import LazyLoadingImage, generate_caption
-from imaginairy.api import imagine_image_files, load_model
+from imaginairy.api import imagine_image_files
 from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
 from imaginairy.schema import ImaginePrompt
 from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
@@ -173,6 +173,12 @@ def configure_logging(level="INFO"):
     type=click.Choice(["full", "autocast"]),
     default="autocast",
 )
+@click.option(
+    "--model-weights-path",
+    help="path to model weights file. by default we use stable diffusion 1.4",
+    type=click.Path(exists=True),
+    default=None,
+)
 @click.pass_context
 def imagine_cmd(
     ctx,
@@ -201,6 +207,7 @@ def imagine_cmd(
     mask_modify_original,
     caption,
     precision,
+    model_weights_path,
 ):
     """Have the AI generate images. alias:imagine"""
     if ctx.invoked_subcommand is not None:
@@ -223,7 +230,6 @@ def imagine_cmd(
     if fix_faces_fidelity is not None:
         fix_faces_fidelity = float(fix_faces_fidelity)
     prompts = []
-    load_model()
     for _ in range(repeats):
         for prompt_text in prompt_texts:
             prompt = ImaginePrompt(
@@ -255,6 +261,7 @@ def imagine_cmd(
             output_file_extension="jpg",
             print_caption=caption,
             precision=precision,
+            model_weights_path=model_weights_path,
         )