feature: allow loading custom model weights at command line

Addresses #40
pull/49/head
Bryce 2 years ago committed by Bryce Drennan
parent 282fbc19b5
commit 340a90bacd

@ -50,10 +50,23 @@ class SafetyMode:
# the press or governments to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
DEFAULT_MODEL_WEIGHTS_LOCATION = (
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
)
def load_model_from_config(config):
url = "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
ckpt_path = cached_path(url)
def load_model_from_config(
config, model_weights_location=DEFAULT_MODEL_WEIGHTS_LOCATION
):
model_weights_location = (
model_weights_location
if model_weights_location
else DEFAULT_MODEL_WEIGHTS_LOCATION
)
if model_weights_location.startswith("http"):
ckpt_path = cached_path(model_weights_location)
else:
ckpt_path = model_weights_location
logger.info(f"Loading model onto {get_device()} backend...")
logger.debug(f"Loading model from {ckpt_path}")
pl_sd = torch.load(ckpt_path, map_location="cpu")
@ -73,10 +86,12 @@ def load_model_from_config(config):
@lru_cache()
def load_model():
def load_model(model_weights_location=None):
config = "configs/stable-diffusion-v1.yaml"
config = OmegaConf.load(f"{LIB_PATH}/{config}")
model = load_model_from_config(config)
model = load_model_from_config(
config, model_weights_location=model_weights_location
)
model = model.to(get_device())
return model
@ -91,6 +106,7 @@ def imagine_image_files(
record_step_images=False,
output_file_extension="jpg",
print_caption=False,
model_weights_path=None,
):
generated_imgs_path = os.path.join(outdir, "generated")
os.makedirs(generated_imgs_path, exist_ok=True)
@ -118,6 +134,7 @@ def imagine_image_files(
ddim_eta=ddim_eta,
img_callback=_record_step if record_step_images else None,
add_caption=print_caption,
model_weights_path=model_weights_path,
):
prompt = result.prompt
img_str = ""
@ -146,8 +163,9 @@ def imagine(
img_callback=None,
half_mode=None,
add_caption=False,
model_weights_path=None,
):
model = load_model()
model = load_model(model_weights_location=model_weights_path)
# only run half-mode on cuda. run it by default
half_mode = half_mode is None and get_device() == "cuda"

@ -4,7 +4,7 @@ import click
from click_shell import shell
from imaginairy import LazyLoadingImage, generate_caption
from imaginairy.api import imagine_image_files, load_model
from imaginairy.api import imagine_image_files
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
from imaginairy.schema import ImaginePrompt
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
@ -173,6 +173,12 @@ def configure_logging(level="INFO"):
type=click.Choice(["full", "autocast"]),
default="autocast",
)
@click.option(
"--model-weights-path",
help="path to model weights file. by default we use stable diffusion 1.4",
type=click.Path(exists=True),
default=None,
)
@click.pass_context
def imagine_cmd(
ctx,
@ -201,6 +207,7 @@ def imagine_cmd(
mask_modify_original,
caption,
precision,
model_weights_path,
):
"""Have the AI generate images. alias:imagine"""
if ctx.invoked_subcommand is not None:
@ -223,7 +230,6 @@ def imagine_cmd(
if fix_faces_fidelity is not None:
fix_faces_fidelity = float(fix_faces_fidelity)
prompts = []
load_model()
for _ in range(repeats):
for prompt_text in prompt_texts:
prompt = ImaginePrompt(
@ -255,6 +261,7 @@ def imagine_cmd(
output_file_extension="jpg",
print_caption=caption,
precision=precision,
model_weights_path=model_weights_path,
)

Loading…
Cancel
Save