feature: urls as init images

- --init-image accepts urls
- cleanup command line code
This commit is contained in:
Bryce 2022-09-15 23:06:59 -07:00
parent 47db34647b
commit c7a822d701
14 changed files with 140 additions and 36 deletions

View File

@ -6,5 +6,4 @@ ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_ROOT_USER_ACTION=ignore
RUN pip install imaginairy
#RUN #imagine pizza
RUN imagine --help

View File

@ -84,15 +84,20 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
## How To
```python
from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt
from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt, LazyLoadingImage
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6c/Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg/540px-Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg"
prompts = [
ImaginePrompt("a scenic landscape", seed=1),
ImaginePrompt("a bowl of fruit"),
ImaginePrompt([
WeightedPrompt("cat", weight=1),
WeightedPrompt("dog", weight=1),
])
]),
ImaginePrompt(
"a spacious building",
init_image=LazyLoadingImage(url)
)
]
for result in imagine(prompts):
# do something

View File

@ -4,4 +4,9 @@ import os
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
from .api import imagine, imagine_image_files # noqa
from .schema import ImaginePrompt, ImagineResult, WeightedPrompt # noqa
from .schema import ( # noqa
ImaginePrompt,
ImagineResult,
LazyLoadingImage,
WeightedPrompt,
)

View File

@ -23,8 +23,8 @@ from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_nn_layer_norm,
get_device,
img_path_to_torch_image,
instantiate_from_config,
pillow_img_to_torch_image,
)
LIB_PATH = os.path.dirname(__file__)
@ -67,6 +67,7 @@ def load_model_from_config(config):
def patch_conv(**patch):
"""
Patch to enable tiling mode
https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main
"""
cls = torch.nn.Conv2d
@ -204,7 +205,11 @@ def imagine(
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
init_image, w, h = img_path_to_torch_image(prompt.init_image)
init_image, w, h = pillow_img_to_torch_image(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)

View File

@ -1,12 +0,0 @@
def imagine_cmd(*args, **kwargs):
from .suppress_logs import suppress_annoying_logs_and_warnings # noqa
suppress_annoying_logs_and_warnings()
from imaginairy.cmds import imagine_cmd as imagine_cmd_orig # noqa
imagine_cmd_orig(*args, **kwargs)
if __name__ == "__main__":
imagine_cmd() # noqa

View File

@ -2,8 +2,10 @@ import logging.config
import click
from imaginairy import LazyLoadingImage
from imaginairy.api import load_model
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
logger = logging.getLogger(__name__)
@ -54,7 +56,7 @@ def configure_logging(level="INFO"):
)
@click.option(
"--init-image",
help="Starting image.",
help="Starting image. filepath or url",
)
@click.option(
"--init-image-strength",
@ -146,6 +148,7 @@ def imagine_cmd(
tile,
):
"""Render an image"""
suppress_annoying_logs_and_warnings()
configure_logging(log_level)
from imaginairy.api import imagine_image_files
@ -159,10 +162,14 @@ def imagine_cmd(
sampler_type = "ddim"
logger.info(" Sampler type switched to ddim for img2img")
if init_image and init_image.startswith("http"):
init_image = LazyLoadingImage(url=init_image)
prompts = []
load_model(tile_mode=tile)
for _ in range(repeats):
for prompt_text in prompt_texts:
prompt = ImaginePrompt(
prompt_text,
prompt_strength=prompt_strength,

View File

@ -1,13 +1,70 @@
import hashlib
import json
import logging
import os.path
import random
from datetime import datetime, timezone
import numpy
from PIL.Image import Exif
import requests
from PIL import Image
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
from imaginairy.utils import get_device, get_device_name
logger = logging.getLogger(__name__)
class InvalidUrlError(ValueError):
pass
class LazyLoadingImage:
def __init__(self, *, filepath=None, url=None):
if not filepath and not url:
raise ValueError("You must specify a url or filepath")
if filepath and url:
raise ValueError("You cannot specify a url and filepath")
# validate file exists
if filepath and not os.path.exists(filepath):
raise FileNotFoundError(f"File does not exist: {filepath}")
# validate url is valid url
if url:
try:
parsed_url = parse_url(url)
except LocationParseError:
raise InvalidUrlError(f"Invalid url: {url}")
if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
raise InvalidUrlError(f"Invalid url: {url}")
self._lazy_filepath = filepath
self._lazy_url = url
self._img = None
def __getattr__(self, key):
if key == "_img":
# http://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
raise AttributeError()
if self._lazy_filepath:
self._img = Image.open(self._lazy_filepath)
logger.info(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
)
elif self._lazy_url:
self._img = Image.open(requests.get(self._lazy_url, stream=True).raw)
logger.info(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
)
return getattr(self._img, key)
def __str__(self):
return self._lazy_filepath or self._lazy_url
class WeightedPrompt:
def __init__(self, text, weight=1):
@ -23,7 +80,7 @@ class ImaginePrompt:
self,
prompt=None,
prompt_strength=7.5,
init_image=None,
init_image=None, # Pillow Image, LazyLoadingImage, or filepath str
init_image_strength=0.3,
seed=None,
steps=50,
@ -40,6 +97,8 @@ class ImaginePrompt:
self.prompts = prompt
self.prompts.sort(key=lambda p: p.weight, reverse=True)
self.prompt_strength = prompt_strength
if isinstance(init_image, str):
init_image = LazyLoadingImage(filepath=init_image)
self.init_image = init_image
self.init_image_strength = init_image_strength
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
@ -68,7 +127,7 @@ class ImaginePrompt:
"software": "imaginairy",
"prompts": prompts,
"prompt_strength": self.prompt_strength,
"init_image": self.init_image,
"init_image": str(self.init_image),
"init_image_strength": self.init_image_strength,
"seed": self.seed,
"steps": self.steps,
@ -116,7 +175,7 @@ class ImagineResult:
}
def _exif(self):
exif = Exif()
exif = Image.Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
# help future web scrapes not ingest AI generated art

View File

@ -1,22 +1,21 @@
import logging.config
import warnings
from pytorch_lightning import _logger as pytorch_logger
from transformers.modeling_utils import logger as modeling_logger
from transformers.utils.logging import _configure_library_root_logger
def disable_transformers_custom_logging():
from transformers.modeling_utils import logger
from transformers.utils.logging import _configure_library_root_logger
_configure_library_root_logger()
logger = logger.parent
logger = modeling_logger.parent
logger.handlers = []
logger.propagate = True
logger.setLevel(logging.NOTSET)
def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger
_logger.setLevel(logging.NOTSET)
pytorch_logger.setLevel(logging.NOTSET)
def disable_common_warnings():

View File

@ -107,6 +107,7 @@ def img_path_to_torch_image(path, max_height=512, max_width=512):
def pillow_img_to_torch_image(image, max_height=512, max_width=512):
image = image.convert("RGB")
w, h = image.size
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)

View File

@ -17,7 +17,7 @@ setup(
},
packages=find_packages(include=("imaginairy", "imaginairy.*")),
entry_points={
"console_scripts": ["imagine=imaginairy.cmd_wrap:imagine_cmd"],
"console_scripts": ["imagine=imaginairy.cmds:imagine_cmd"],
},
package_data={"imaginairy": ["configs/*.yaml", "vendored/clip/*.txt.gz"]},
install_requires=[

View File

@ -2,6 +2,7 @@ import sys
import pytest
from imaginairy import api
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
if "pytest" in str(sys.argv):
@ -10,6 +11,4 @@ if "pytest" in str(sys.argv):
@pytest.fixture(scope="session", autouse=True)
def pre_setup():
from imaginairy import api
api.IMAGINAIRY_SAFETY_MODE = "disabled"

View File

@ -1,5 +1,6 @@
import pytest
from imaginairy import LazyLoadingImage
from imaginairy.api import imagine, imagine_image_files
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import get_device
@ -59,6 +60,23 @@ def test_img_to_img():
imagine_image_files(prompt, outdir=out_folder)
def test_img_to_img_from_url():
prompt = ImaginePrompt(
"dogs lying on a hot pink couch",
init_image=LazyLoadingImage(
url="http://images.cocodataset.org/val2017/000000039769.jpg"
),
init_image_strength=0.5,
width=512,
height=512,
steps=50,
seed=1,
sampler_type="DDIM",
)
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompt, outdir=out_folder)
def test_img_to_file():
prompt = ImaginePrompt(
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",

19
tests/test_schema.py Normal file
View File

@ -0,0 +1,19 @@
import pytest
from imaginairy import LazyLoadingImage
from imaginairy.schema import InvalidUrlError
from tests import TESTS_FOLDER
def test_lazy_load_image():
with pytest.raises(ValueError, match=r".*specify a url or filepath.*"):
LazyLoadingImage()
with pytest.raises(FileNotFoundError, match=r".*File does not exist.*"):
LazyLoadingImage(filepath="/tmp/bterpojirewpdfsn/ergqgr")
with pytest.raises(InvalidUrlError):
LazyLoadingImage(url="/tmp/bterpojirewpdfsn/ergqgr")
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
assert img.size == (1686, 1246)

View File

@ -8,11 +8,11 @@ filterwarnings =
[pylama]
format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D107,D202,D203,D212,D400,D401,D415,
Z999,D100,D101,D102,D103,D105,D107,D202,D203,D212,D400,D401,D415,
Z999,E501,E1101,
Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
Z999,W0221,W0511,W1203