2022-09-10 05:14:04 +00:00
|
|
|
import hashlib
|
2022-09-11 06:27:22 +00:00
|
|
|
import json
|
2022-09-16 06:06:59 +00:00
|
|
|
import logging
|
|
|
|
import os.path
|
2022-09-10 05:14:04 +00:00
|
|
|
import random
|
2022-09-11 06:27:22 +00:00
|
|
|
from datetime import datetime, timezone
|
2022-09-17 05:34:42 +00:00
|
|
|
from functools import lru_cache
|
2022-09-10 05:14:04 +00:00
|
|
|
|
2022-09-16 06:06:59 +00:00
|
|
|
import requests
|
2022-09-24 05:58:48 +00:00
|
|
|
from PIL import Image, ImageOps
|
2022-09-16 06:06:59 +00:00
|
|
|
from urllib3.exceptions import LocationParseError
|
|
|
|
from urllib3.util import parse_url
|
2022-09-11 06:27:22 +00:00
|
|
|
|
2022-10-04 22:07:40 +00:00
|
|
|
from imaginairy.utils import get_device, get_hardware_description
|
2022-09-10 05:14:04 +00:00
|
|
|
|
2022-09-16 06:06:59 +00:00
|
|
|
# Module-level logger; LazyLoadingImage reports each input image it loads through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class InvalidUrlError(ValueError):
    """Signals a URL that cannot be parsed or is not a usable http(s) address with a host."""
|
|
|
|
|
|
|
|
|
|
|
|
class LazyLoadingImage:
    """Stand-in for a PIL image that defers reading it until first use.

    Exactly one of ``filepath`` or ``url`` must be given. Both are validated
    eagerly. The image itself is only opened the first time an attribute not
    defined on this wrapper is accessed. From then on, every such attribute
    access is forwarded to the loaded PIL image. EXIF orientation is applied
    once, right after loading.

    Raises (from ``__init__``):
        ValueError: if neither or both of ``filepath`` and ``url`` are given.
        FileNotFoundError: if ``filepath`` does not exist.
        InvalidUrlError: if ``url`` cannot be parsed, is not http/https, or
            has no host.
    """

    def __init__(self, *, filepath=None, url=None):
        if not filepath and not url:
            raise ValueError("You must specify a url or filepath")
        if filepath and url:
            raise ValueError("You cannot specify a url and filepath")

        # validate file exists
        if filepath and not os.path.exists(filepath):
            raise FileNotFoundError(f"File does not exist: {filepath}")

        # validate url is valid url
        if url:
            try:
                parsed_url = parse_url(url)
            except LocationParseError as err:
                raise InvalidUrlError(f"Invalid url: {url}") from err
            if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
                raise InvalidUrlError(f"Invalid url: {url}")

        self._lazy_filepath = filepath
        self._lazy_url = url
        self._img = None

    def __getattr__(self, key):
        """Load the image on first use and forward attribute access to it."""
        if key == "_img":
            # ``_img`` itself missing (e.g. instance built without __init__):
            # fail plainly instead of recursing back into __getattr__.
            # http://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
            raise AttributeError()
        if self._img is not None:
            return getattr(self._img, key)

        if self._lazy_filepath:
            self._img = Image.open(self._lazy_filepath)
            logger.info(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
            )
        elif self._lazy_url:
            self._img = self._open_url(self._lazy_url)
            logger.info(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
            )
        # fix orientation
        self._img = ImageOps.exif_transpose(self._img)

        return getattr(self._img, key)

    @staticmethod
    def _open_url(url):
        """Download ``url`` and return it as a fully loaded PIL image.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(url, stream=True, timeout=60)
        try:
            # Surface HTTP failures directly instead of handing an error page to PIL.
            response.raise_for_status()
            # ``raw`` is the undecoded byte stream; have urllib3 undo any
            # Content-Encoding (gzip/deflate) so PIL receives the image bytes.
            response.raw.decode_content = True
            img = Image.open(response.raw)
            # Read all pixel data now so the connection can be released below.
            img.load()
        finally:
            response.close()
        return img

    def __str__(self):
        return self._lazy_filepath or self._lazy_url
|
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
|
|
|
|
class WeightedPrompt:
    """Pair a prompt string with a relative weight.

    Rendered by ``str()`` as ``"<weight>*(<text>)"``.
    """

    def __init__(self, text, weight=1):
        self.text, self.weight = text, weight

    def __str__(self):
        weight, text = self.weight, self.text
        return f"{weight}*({text})"
|
|
|
|
|
|
|
|
|
|
|
|
class ImaginePrompt:
    """Settings describing a single image-generation request.

    ``prompt`` may be a plain string, which is treated as one prompt of
    weight 1. It may also be an iterable of weighted prompts (objects with
    ``text`` and ``weight``). The stored ``prompts`` list is a new list
    ordered by descending weight, so the caller's sequence is left untouched.
    """

    class MaskMode:
        # Allowed values for the ``mask_mode`` argument.
        KEEP = "keep"
        REPLACE = "replace"

    DEFAULT_FACE_FIDELITY = 0.2

    def __init__(
        self,
        prompt=None,
        prompt_strength=7.5,
        init_image=None,  # Pillow Image, LazyLoadingImage, or filepath str
        init_image_strength=0.3,
        mask_prompt=None,
        mask_image=None,
        mask_mode=MaskMode.REPLACE,
        mask_modify_original=True,
        seed=None,
        steps=50,
        height=512,
        width=512,
        upscale=False,
        fix_faces=False,
        fix_faces_fidelity=DEFAULT_FACE_FIDELITY,
        sampler_type="PLMS",
        conditioning=None,
        tile_mode=False,
    ):
        """Store the request settings.

        A str ``init_image`` or ``mask_image`` is treated as a file path and
        wrapped in ``LazyLoadingImage``. ``seed=None`` picks a random seed in
        [1, 1_000_000_000]. ``fix_faces_fidelity=None`` falls back to
        ``DEFAULT_FACE_FIDELITY``. Any explicit number, including 0, is kept.

        Raises:
            ValueError: if both ``mask_image`` and ``mask_prompt`` are given.
        """
        prompt = prompt if prompt is not None else ""
        # Only substitute the default when no value was supplied; 0 is valid.
        if fix_faces_fidelity is None:
            fix_faces_fidelity = self.DEFAULT_FACE_FIDELITY

        if isinstance(prompt, str):
            prompts = [WeightedPrompt(prompt, 1)]
        else:
            prompts = prompt
        # sorted() builds a new list, so a caller-supplied list is not reordered.
        self.prompts = sorted(prompts, key=lambda p: p.weight, reverse=True)
        self.prompt_strength = prompt_strength

        if isinstance(init_image, str):
            init_image = LazyLoadingImage(filepath=init_image)

        if isinstance(mask_image, str):
            mask_image = LazyLoadingImage(filepath=mask_image)

        if mask_image is not None and mask_prompt is not None:
            raise ValueError("You can only set one of `mask_image` and `mask_prompt`")

        self.init_image = init_image
        self.init_image_strength = init_image_strength
        self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
        self.steps = steps
        self.height = height
        self.width = width
        self.upscale = upscale
        self.fix_faces = fix_faces
        self.fix_faces_fidelity = fix_faces_fidelity
        self.sampler_type = sampler_type
        self.conditioning = conditioning
        self.mask_prompt = mask_prompt
        self.mask_image = mask_image
        self.mask_mode = mask_mode
        self.mask_modify_original = mask_modify_original
        self.tile_mode = tile_mode

    @property
    def prompt_text(self):
        """The bare text for a single prompt, else ``weight*(text)`` parts joined by ``|``."""
        if len(self.prompts) == 1:
            return self.prompts[0].text
        return "|".join(str(p) for p in self.prompts)

    def prompt_description(self):
        """One-line human-readable summary of the prompt and key generation settings."""
        return (
            f'🖼 : "{self.prompt_text}" {self.width}x{self.height}px '
            f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type}"
        )

    def as_dict(self):
        """Return the main prompt settings as a plain dict.

        Prompts become ``(weight, text)`` tuples and ``init_image`` is
        stringified, so ``None`` is stored as ``"None"``.
        """
        prompts = [(p.weight, p.text) for p in self.prompts]
        return {
            "software": "imaginairy",
            "prompts": prompts,
            "prompt_strength": self.prompt_strength,
            "init_image": str(self.init_image),
            "init_image_strength": self.init_image_strength,
            "seed": self.seed,
            "steps": self.steps,
            "height": self.height,
            "width": self.width,
            "upscale": self.upscale,
            "fix_faces": self.fix_faces,
            "sampler_type": self.sampler_type,
        }
|
|
|
|
|
|
|
|
|
|
|
|
class ExifCodes:
    """Numeric TIFF/EXIF tag IDs that are written into saved images.

    Reference: https://www.awaresystems.be/imaging/tiff/tifftags/baseline.html
    """

    ImageDescription = 270  # 0x010E
    Software = 305  # 0x0131
    DateTime = 306  # 0x0132
    HostComputer = 316  # 0x013C
    UserComment = 37510  # 0x9286
|
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
|
|
|
|
class ImagineResult:
    """Images produced by one generation run, plus the metadata used when saving.

    ``images`` always holds ``"generated"``. It also holds ``"upscaled"``,
    ``"modified_original"``, ``"mask_binary"`` and ``"mask_grayscale"``
    when those images are supplied.
    """

    def __init__(
        self,
        img,
        prompt: "ImaginePrompt",
        is_nsfw,
        safety_score,
        upscaled_img=None,
        modified_original=None,
        mask_binary=None,
        mask_grayscale=None,
    ):
        self.prompt = prompt

        self.images = {"generated": img}
        optional_images = {
            "upscaled": upscaled_img,
            "modified_original": modified_original,
            "mask_binary": mask_binary,
            "mask_grayscale": mask_grayscale,
        }
        # Only keep the optional images that were actually provided.
        self.images.update(
            (name, image)
            for name, image in optional_images.items()
            if image is not None
        )

        # for backward compat
        self.img = img
        self.upscaled_img = upscaled_img

        self.is_nsfw = is_nsfw
        self.safety_score = safety_score
        self.created_at = datetime.now(timezone.utc)
        self.torch_backend = get_device()
        self.hardware_name = get_hardware_description(self.torch_backend)

    def md5(self):
        """Hex MD5 digest of the generated image's raw pixel bytes."""
        return hashlib.md5(self.img.tobytes()).hexdigest()

    def metadata_dict(self):
        """Metadata embedded (as JSON) in saved images."""
        return {
            "prompt": self.prompt.as_dict(),
        }

    def _exif_datetime(self):
        """``created_at`` in the EXIF/TIFF DateTime format ``YYYY:MM:DD HH:MM:SS``."""
        return self.created_at.strftime("%Y:%m:%d %H:%M:%S")

    def _exif(self):
        """Build the EXIF block attached to every saved image."""
        exif = Image.Exif()
        exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
        exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
        # help future web scrapes not ingest AI generated art
        exif[ExifCodes.Software] = "Imaginairy / Stable Diffusion v1.4"
        # The TIFF spec requires colon-separated dates here, not ISO 8601 dashes.
        exif[ExifCodes.DateTime] = self._exif_datetime()
        exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
        return exif

    def save(self, save_path, image_type="generated"):
        """Save the image stored under ``image_type`` as RGB with EXIF metadata.

        Raises:
            ValueError: if no image of ``image_type`` was stored.
        """
        img = self.images.get(image_type, None)
        if img is None:
            raise ValueError(
                f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
            )

        img.convert("RGB").save(save_path, exif=self._exif())
|
2022-09-24 05:58:48 +00:00
|
|
|
|
2022-09-17 05:34:42 +00:00
|
|
|
|
|
|
|
@lru_cache(maxsize=2)
def _get_briefly_cached_url(url):
    """GET ``url`` (60 s timeout), memoizing the two most recently requested URLs.

    The ``requests`` response object is returned as-is and is shared between
    callers that hit the cache.
    """
    response = requests.get(url, timeout=60)
    return response
|