|
|
|
@ -73,31 +73,11 @@ def load_model_from_config(config):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def patch_conv(**patch):
|
|
|
|
|
"""
|
|
|
|
|
Patch to enable tiling mode
|
|
|
|
|
|
|
|
|
|
https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main
|
|
|
|
|
"""
|
|
|
|
|
cls = torch.nn.Conv2d
|
|
|
|
|
init = cls.__init__
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
return init(self, *args, **kwargs, **patch)
|
|
|
|
|
|
|
|
|
|
cls.__init__ = __init__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
|
|
|
|
|
def load_model(tile_mode=False):
|
|
|
|
|
if tile_mode:
|
|
|
|
|
# generated images are tileable
|
|
|
|
|
patch_conv(padding_mode="circular")
|
|
|
|
|
|
|
|
|
|
def load_model():
|
|
|
|
|
config = "configs/stable-diffusion-v1.yaml"
|
|
|
|
|
config = OmegaConf.load(f"{LIB_PATH}/{config}")
|
|
|
|
|
model = load_model_from_config(config)
|
|
|
|
|
|
|
|
|
|
model = model.to(get_device())
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
@ -111,7 +91,6 @@ def imagine_image_files(
|
|
|
|
|
ddim_eta=0.0,
|
|
|
|
|
record_step_images=False,
|
|
|
|
|
output_file_extension="jpg",
|
|
|
|
|
tile_mode=False,
|
|
|
|
|
print_caption=False,
|
|
|
|
|
):
|
|
|
|
|
big_path = os.path.join(outdir, "upscaled")
|
|
|
|
@ -139,7 +118,6 @@ def imagine_image_files(
|
|
|
|
|
precision=precision,
|
|
|
|
|
ddim_eta=ddim_eta,
|
|
|
|
|
img_callback=_record_step if record_step_images else None,
|
|
|
|
|
tile_mode=tile_mode,
|
|
|
|
|
add_caption=print_caption,
|
|
|
|
|
):
|
|
|
|
|
prompt = result.prompt
|
|
|
|
@ -164,11 +142,10 @@ def imagine(
|
|
|
|
|
precision="autocast",
|
|
|
|
|
ddim_eta=0.0,
|
|
|
|
|
img_callback=None,
|
|
|
|
|
tile_mode=False,
|
|
|
|
|
half_mode=None,
|
|
|
|
|
add_caption=False,
|
|
|
|
|
):
|
|
|
|
|
model = load_model(tile_mode=tile_mode)
|
|
|
|
|
model = load_model()
|
|
|
|
|
|
|
|
|
|
# only run half-mode on cuda. run it by default
|
|
|
|
|
half_mode = half_mode is None and get_device() == "cuda"
|
|
|
|
@ -194,6 +171,7 @@ def imagine(
|
|
|
|
|
):
|
|
|
|
|
logger.info(f"Generating {prompt.prompt_description()}")
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
model.tile_mode(prompt.tile_mode)
|
|
|
|
|
|
|
|
|
|
uc = None
|
|
|
|
|
if prompt.prompt_strength != 1.0:
|
|
|
|
|