mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00
feature: allow loading custom model weights at command line
Addresses #40
This commit is contained in:
parent
282fbc19b5
commit
340a90bacd
@ -50,10 +50,23 @@ class SafetyMode:
|
|||||||
# the press or governments to freak out about AI...
|
# the press or governments to freak out about AI...
|
||||||
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
|
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
|
||||||
|
|
||||||
|
DEFAULT_MODEL_WEIGHTS_LOCATION = (
|
||||||
|
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
|
||||||
|
)
|
||||||
|
|
||||||
def load_model_from_config(config):
|
|
||||||
url = "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
|
def load_model_from_config(
|
||||||
ckpt_path = cached_path(url)
|
config, model_weights_location=DEFAULT_MODEL_WEIGHTS_LOCATION
|
||||||
|
):
|
||||||
|
model_weights_location = (
|
||||||
|
model_weights_location
|
||||||
|
if model_weights_location
|
||||||
|
else DEFAULT_MODEL_WEIGHTS_LOCATION
|
||||||
|
)
|
||||||
|
if model_weights_location.startswith("http"):
|
||||||
|
ckpt_path = cached_path(model_weights_location)
|
||||||
|
else:
|
||||||
|
ckpt_path = model_weights_location
|
||||||
logger.info(f"Loading model onto {get_device()} backend...")
|
logger.info(f"Loading model onto {get_device()} backend...")
|
||||||
logger.debug(f"Loading model from {ckpt_path}")
|
logger.debug(f"Loading model from {ckpt_path}")
|
||||||
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
||||||
@ -73,10 +86,12 @@ def load_model_from_config(config):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_model():
|
def load_model(model_weights_location=None):
|
||||||
config = "configs/stable-diffusion-v1.yaml"
|
config = "configs/stable-diffusion-v1.yaml"
|
||||||
config = OmegaConf.load(f"{LIB_PATH}/{config}")
|
config = OmegaConf.load(f"{LIB_PATH}/{config}")
|
||||||
model = load_model_from_config(config)
|
model = load_model_from_config(
|
||||||
|
config, model_weights_location=model_weights_location
|
||||||
|
)
|
||||||
model = model.to(get_device())
|
model = model.to(get_device())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -91,6 +106,7 @@ def imagine_image_files(
|
|||||||
record_step_images=False,
|
record_step_images=False,
|
||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=False,
|
print_caption=False,
|
||||||
|
model_weights_path=None,
|
||||||
):
|
):
|
||||||
generated_imgs_path = os.path.join(outdir, "generated")
|
generated_imgs_path = os.path.join(outdir, "generated")
|
||||||
os.makedirs(generated_imgs_path, exist_ok=True)
|
os.makedirs(generated_imgs_path, exist_ok=True)
|
||||||
@ -118,6 +134,7 @@ def imagine_image_files(
|
|||||||
ddim_eta=ddim_eta,
|
ddim_eta=ddim_eta,
|
||||||
img_callback=_record_step if record_step_images else None,
|
img_callback=_record_step if record_step_images else None,
|
||||||
add_caption=print_caption,
|
add_caption=print_caption,
|
||||||
|
model_weights_path=model_weights_path,
|
||||||
):
|
):
|
||||||
prompt = result.prompt
|
prompt = result.prompt
|
||||||
img_str = ""
|
img_str = ""
|
||||||
@ -146,8 +163,9 @@ def imagine(
|
|||||||
img_callback=None,
|
img_callback=None,
|
||||||
half_mode=None,
|
half_mode=None,
|
||||||
add_caption=False,
|
add_caption=False,
|
||||||
|
model_weights_path=None,
|
||||||
):
|
):
|
||||||
model = load_model()
|
model = load_model(model_weights_location=model_weights_path)
|
||||||
|
|
||||||
# only run half-mode on cuda. run it by default
|
# only run half-mode on cuda. run it by default
|
||||||
half_mode = half_mode is None and get_device() == "cuda"
|
half_mode = half_mode is None and get_device() == "cuda"
|
||||||
|
@ -4,7 +4,7 @@ import click
|
|||||||
from click_shell import shell
|
from click_shell import shell
|
||||||
|
|
||||||
from imaginairy import LazyLoadingImage, generate_caption
|
from imaginairy import LazyLoadingImage, generate_caption
|
||||||
from imaginairy.api import imagine_image_files, load_model
|
from imaginairy.api import imagine_image_files
|
||||||
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
|
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
|
||||||
from imaginairy.schema import ImaginePrompt
|
from imaginairy.schema import ImaginePrompt
|
||||||
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
||||||
@ -173,6 +173,12 @@ def configure_logging(level="INFO"):
|
|||||||
type=click.Choice(["full", "autocast"]),
|
type=click.Choice(["full", "autocast"]),
|
||||||
default="autocast",
|
default="autocast",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model-weights-path",
|
||||||
|
help="path to model weights file. by default we use stable diffusion 1.4",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def imagine_cmd(
|
def imagine_cmd(
|
||||||
ctx,
|
ctx,
|
||||||
@ -201,6 +207,7 @@ def imagine_cmd(
|
|||||||
mask_modify_original,
|
mask_modify_original,
|
||||||
caption,
|
caption,
|
||||||
precision,
|
precision,
|
||||||
|
model_weights_path,
|
||||||
):
|
):
|
||||||
"""Have the AI generate images. alias:imagine"""
|
"""Have the AI generate images. alias:imagine"""
|
||||||
if ctx.invoked_subcommand is not None:
|
if ctx.invoked_subcommand is not None:
|
||||||
@ -223,7 +230,6 @@ def imagine_cmd(
|
|||||||
if fix_faces_fidelity is not None:
|
if fix_faces_fidelity is not None:
|
||||||
fix_faces_fidelity = float(fix_faces_fidelity)
|
fix_faces_fidelity = float(fix_faces_fidelity)
|
||||||
prompts = []
|
prompts = []
|
||||||
load_model()
|
|
||||||
for _ in range(repeats):
|
for _ in range(repeats):
|
||||||
for prompt_text in prompt_texts:
|
for prompt_text in prompt_texts:
|
||||||
prompt = ImaginePrompt(
|
prompt = ImaginePrompt(
|
||||||
@ -255,6 +261,7 @@ def imagine_cmd(
|
|||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=caption,
|
print_caption=caption,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
|
model_weights_path=model_weights_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user