From c7a822d7017b0467e5d786da912f9928e95d1f0e Mon Sep 17 00:00:00 2001 From: Bryce Date: Thu, 15 Sep 2022 23:06:59 -0700 Subject: [PATCH] feature: urls as init images - --init-image accepts urls - cleanup command line code --- Dockerfile | 3 +- README.md | 9 +++-- imaginairy/__init__.py | 7 +++- imaginairy/api.py | 9 +++-- imaginairy/cmd_wrap.py | 12 ------- imaginairy/cmds.py | 9 ++++- imaginairy/schema.py | 67 ++++++++++++++++++++++++++++++++++--- imaginairy/suppress_logs.py | 13 ++++--- imaginairy/utils.py | 1 + setup.py | 2 +- tests/conftest.py | 3 +- tests/test_imagine.py | 18 ++++++++++ tests/test_schema.py | 19 +++++++++++ tox.ini | 4 +-- 14 files changed, 140 insertions(+), 36 deletions(-) delete mode 100644 imaginairy/cmd_wrap.py create mode 100644 tests/test_schema.py diff --git a/Dockerfile b/Dockerfile index e40e653..6fbfdfb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,5 +6,4 @@ ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \ PIP_ROOT_USER_ACTION=ignore RUN pip install imaginairy - -#RUN #imagine pizza \ No newline at end of file +RUN imagine --help diff --git a/README.md b/README.md index 48ec214..5dc160f 100644 --- a/README.md +++ b/README.md @@ -84,15 +84,20 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645 ## How To ```python -from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt +from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedPrompt, LazyLoadingImage +url = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6c/Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg/540px-Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg" prompts = [ ImaginePrompt("a scenic landscape", seed=1), ImaginePrompt("a bowl of fruit"), ImaginePrompt([ WeightedPrompt("cat", weight=1), WeightedPrompt("dog", weight=1), - ]) + ]), + ImaginePrompt( + "a spacious building", + init_image=LazyLoadingImage(url) + ) ] for result in imagine(prompts): # do something diff --git a/imaginairy/__init__.py b/imaginairy/__init__.py index f65fcfc..f13573b 100644 --- a/imaginairy/__init__.py +++ b/imaginairy/__init__.py @@ -4,4 +4,9 @@ import os os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1") from .api import imagine, imagine_image_files # noqa -from .schema import ImaginePrompt, ImagineResult, WeightedPrompt # noqa +from .schema import ( # noqa + ImaginePrompt, + ImagineResult, + LazyLoadingImage, + WeightedPrompt, +) diff --git a/imaginairy/api.py b/imaginairy/api.py index e94eb30..4f8c06c 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -23,8 +23,8 @@ from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.utils import ( fix_torch_nn_layer_norm, get_device, - img_path_to_torch_image, instantiate_from_config, + pillow_img_to_torch_image, ) LIB_PATH = os.path.dirname(__file__) @@ -67,6 +67,7 @@ def load_model_from_config(config): def patch_conv(**patch): """ Patch to enable tiling mode + https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main """ cls = torch.nn.Conv2d @@ -204,7 +205,11 @@ def imagine( ddim_steps = int(prompt.steps / generation_strength) sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta) - init_image, w, h = img_path_to_torch_image(prompt.init_image) + init_image, w, h = pillow_img_to_torch_image( + prompt.init_image, + max_height=prompt.height, + max_width=prompt.width, + ) init_image = init_image.to(get_device()) init_latent = model.get_first_stage_encoding( model.encode_first_stage(init_image) diff --git a/imaginairy/cmd_wrap.py b/imaginairy/cmd_wrap.py deleted file mode 100644 index 6ed4353..0000000 --- a/imaginairy/cmd_wrap.py +++ /dev/null @@ -1,12 +0,0 @@ -def imagine_cmd(*args, **kwargs): - from .suppress_logs import suppress_annoying_logs_and_warnings # noqa - - suppress_annoying_logs_and_warnings() - - from imaginairy.cmds import imagine_cmd as imagine_cmd_orig # noqa - - imagine_cmd_orig(*args, **kwargs) - - -if __name__ == "__main__": - imagine_cmd() # noqa diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py index 839e2b2..99954b4 100644 --- a/imaginairy/cmds.py +++ b/imaginairy/cmds.py @@ -2,8 +2,10 @@ import logging.config import click +from imaginairy import LazyLoadingImage from imaginairy.api import load_model from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS +from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings logger = logging.getLogger(__name__) @@ -54,7 +56,7 @@ def configure_logging(level="INFO"): ) @click.option( "--init-image", - help="Starting image.", + help="Starting image. filepath or url", ) @click.option( "--init-image-strength", @@ -146,6 +148,7 @@ def imagine_cmd( tile, ): """Render an image""" + suppress_annoying_logs_and_warnings() configure_logging(log_level) from imaginairy.api import imagine_image_files @@ -159,10 +162,14 @@ def imagine_cmd( sampler_type = "ddim" logger.info(" Sampler type switched to ddim for img2img") + if init_image and init_image.startswith("http"): + init_image = LazyLoadingImage(url=init_image) + prompts = [] load_model(tile_mode=tile) for _ in range(repeats): for prompt_text in prompt_texts: + prompt = ImaginePrompt( prompt_text, prompt_strength=prompt_strength, diff --git a/imaginairy/schema.py b/imaginairy/schema.py index abfab5f..8cacb32 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -1,13 +1,70 @@ import hashlib import json +import logging +import os.path import random from datetime import datetime, timezone import numpy -from PIL.Image import Exif +import requests +from PIL import Image +from urllib3.exceptions import LocationParseError +from urllib3.util import parse_url from imaginairy.utils import get_device, get_device_name +logger = logging.getLogger(__name__) + + +class InvalidUrlError(ValueError): + pass + + +class LazyLoadingImage: + def __init__(self, *, filepath=None, url=None): + if not filepath and not url: + raise ValueError("You must specify a url or filepath") + if filepath and url: + raise ValueError("You cannot specify a url and filepath") + + # validate file exists + if filepath and not os.path.exists(filepath): + raise FileNotFoundError(f"File does not exist: {filepath}") + + # validate url is valid url + if url: + try: + parsed_url = parse_url(url) + except LocationParseError: + raise InvalidUrlError(f"Invalid url: {url}") + if parsed_url.scheme not in {"http", "https"} or not parsed_url.host: + raise InvalidUrlError(f"Invalid url: {url}") + + self._lazy_filepath = filepath + self._lazy_url = url + self._img = None + + def __getattr__(self, key): + if key == "_img": + # http://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html + raise AttributeError() + + if self._lazy_filepath: + self._img = Image.open(self._lazy_filepath) + logger.info( + f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}" + ) + elif self._lazy_url: + self._img = Image.open(requests.get(self._lazy_url, stream=True).raw) + logger.info( + f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}" + ) + + return getattr(self._img, key) + + def __str__(self): + return self._lazy_filepath or self._lazy_url + class WeightedPrompt: def __init__(self, text, weight=1): @@ -23,7 +80,7 @@ class ImaginePrompt: self, prompt=None, prompt_strength=7.5, - init_image=None, + init_image=None, # Pillow Image, LazyLoadingImage, or filepath str init_image_strength=0.3, seed=None, steps=50, @@ -40,6 +97,8 @@ class ImaginePrompt: self.prompts = prompt self.prompts.sort(key=lambda p: p.weight, reverse=True) self.prompt_strength = prompt_strength + if isinstance(init_image, str): + init_image = LazyLoadingImage(filepath=init_image) self.init_image = init_image self.init_image_strength = init_image_strength self.seed = random.randint(1, 1_000_000_000) if seed is None else seed @@ -68,7 +127,7 @@ class ImaginePrompt: "software": "imaginairy", "prompts": prompts, "prompt_strength": self.prompt_strength, - "init_image": self.init_image, + "init_image": str(self.init_image), "init_image_strength": self.init_image_strength, "seed": self.seed, "steps": self.steps, @@ -116,7 +175,7 @@ class ImagineResult: } def _exif(self): - exif = Exif() + exif = Image.Exif() exif[ExifCodes.ImageDescription] = self.prompt.prompt_description() exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict()) # help future web scrapes not ingest AI generated art diff --git a/imaginairy/suppress_logs.py b/imaginairy/suppress_logs.py index bec2214..244d5ea 100644 --- a/imaginairy/suppress_logs.py +++ b/imaginairy/suppress_logs.py @@ -1,22 +1,21 @@ import logging.config import warnings +from pytorch_lightning import _logger as pytorch_logger +from transformers.modeling_utils import logger as modeling_logger +from transformers.utils.logging import _configure_library_root_logger -def disable_transformers_custom_logging(): - from transformers.modeling_utils import logger - from transformers.utils.logging import _configure_library_root_logger +def disable_transformers_custom_logging(): _configure_library_root_logger() - logger = logger.parent + logger = modeling_logger.parent logger.handlers = [] logger.propagate = True logger.setLevel(logging.NOTSET) def disable_pytorch_lighting_custom_logging(): - from pytorch_lightning import _logger - - _logger.setLevel(logging.NOTSET) + pytorch_logger.setLevel(logging.NOTSET) def disable_common_warnings(): diff --git a/imaginairy/utils.py b/imaginairy/utils.py index 7de7acc..ab60915 100644 --- a/imaginairy/utils.py +++ b/imaginairy/utils.py @@ -107,6 +107,7 @@ def img_path_to_torch_image(path, max_height=512, max_width=512): def pillow_img_to_torch_image(image, max_height=512, max_width=512): + image = image.convert("RGB") w, h = image.size resize_ratio = min(max_width / w, max_height / h) w, h = int(w * resize_ratio), int(h * resize_ratio) diff --git a/setup.py b/setup.py index bf33914..5c35982 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( }, packages=find_packages(include=("imaginairy", "imaginairy.*")), entry_points={ - "console_scripts": ["imagine=imaginairy.cmd_wrap:imagine_cmd"], + "console_scripts": ["imagine=imaginairy.cmds:imagine_cmd"], }, package_data={"imaginairy": ["configs/*.yaml", "vendored/clip/*.txt.gz"]}, install_requires=[ diff --git a/tests/conftest.py b/tests/conftest.py index 9983fb4..1e04441 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import sys import pytest +from imaginairy import api from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings if "pytest" in str(sys.argv): @@ -10,6 +11,4 @@ if "pytest" in str(sys.argv): @pytest.fixture(scope="session", autouse=True) def pre_setup(): - from imaginairy import api - api.IMAGINAIRY_SAFETY_MODE = "disabled" diff --git a/tests/test_imagine.py b/tests/test_imagine.py index 711f680..3b1fba5 100644 --- a/tests/test_imagine.py +++ b/tests/test_imagine.py @@ -1,5 +1,6 @@ import pytest +from imaginairy import LazyLoadingImage from imaginairy.api import imagine, imagine_image_files from imaginairy.schema import ImaginePrompt from imaginairy.utils import get_device @@ -59,6 +60,23 @@ def test_img_to_img(): imagine_image_files(prompt, outdir=out_folder) +def test_img_to_img_from_url(): + prompt = ImaginePrompt( + "dogs lying on a hot pink couch", + init_image=LazyLoadingImage( + url="http://images.cocodataset.org/val2017/000000039769.jpg" + ), + init_image_strength=0.5, + width=512, + height=512, + steps=50, + seed=1, + sampler_type="DDIM", + ) + out_folder = f"{TESTS_FOLDER}/test_output" + imagine_image_files(prompt, outdir=out_folder) + + def test_img_to_file(): prompt = ImaginePrompt( "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo", diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..4b31882 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,19 @@ +import pytest + +from imaginairy import LazyLoadingImage +from imaginairy.schema import InvalidUrlError +from tests import TESTS_FOLDER + + +def test_lazy_load_image(): + with pytest.raises(ValueError, match=r".*specify a url or filepath.*"): + LazyLoadingImage() + + with pytest.raises(FileNotFoundError, match=r".*File does not exist.*"): + LazyLoadingImage(filepath="/tmp/bterpojirewpdfsn/ergqgr") + + with pytest.raises(InvalidUrlError): + LazyLoadingImage(url="/tmp/bterpojirewpdfsn/ergqgr") + + img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg") + assert img.size == (1686, 1246) diff --git a/tox.ini b/tox.ini index a7338b9..d8a1b9b 100644 --- a/tox.ini +++ b/tox.ini @@ -8,11 +8,11 @@ filterwarnings = [pylama] format = pylint -skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/* +skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/* linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy ignore = Z999,C0103,C0301,C0114,C0115,C0116, - Z999,D100,D101,D102,D103,D107,D202,D203,D212,D400,D401,D415, + Z999,D100,D101,D102,D103,D105,D107,D202,D203,D212,D400,D401,D415, Z999,E501,E1101, Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915, Z999,W0221,W0511,W1203