# imaginAIry/imaginairy/schema.py
#
# 192 lines
# 6.0 KiB
# Python
# Raw Normal View History

import hashlib
import json
import logging
import os.path
import random
from datetime import datetime, timezone
import numpy
import requests
from PIL import Image
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
from imaginairy.utils import get_device, get_device_name
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class InvalidUrlError(ValueError):
    """Raised when a URL cannot be parsed or is not an http(s) URL with a host.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """
class LazyLoadingImage:
    """Stand-in for a Pillow image that is only opened when first used.

    Exactly one of ``filepath`` or ``url`` must be supplied. The source is
    validated eagerly in ``__init__``; the image itself is opened on the
    first attribute access that the wrapper cannot satisfy itself, and that
    access (and every later one) is forwarded to the opened image.
    """

    # Attributes owned by the wrapper. They must never be forwarded to the
    # wrapped image: when they are missing (e.g. ``copy``/``pickle`` probing an
    # instance whose ``__init__`` has not run) forwarding would make
    # ``__getattr__`` call itself without bound.
    # http://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
    _OWN_ATTRS = frozenset({"_img", "_lazy_filepath", "_lazy_url"})

    def __init__(self, *, filepath=None, url=None):
        """Validate the image source without opening it.

        Args:
            filepath: Path to an existing image file.
            url: ``http``/``https`` URL of an image.

        Raises:
            ValueError: If neither or both of ``filepath``/``url`` are given.
            FileNotFoundError: If ``filepath`` does not exist.
            InvalidUrlError: If ``url`` cannot be parsed, has a scheme other
                than http/https, or has no host.
        """
        if not filepath and not url:
            raise ValueError("You must specify a url or filepath")
        if filepath and url:
            raise ValueError("You cannot specify a url and filepath")

        # validate file exists
        if filepath and not os.path.exists(filepath):
            raise FileNotFoundError(f"File does not exist: {filepath}")

        # validate url is valid url
        if url:
            try:
                parsed_url = parse_url(url)
            except LocationParseError as err:
                raise InvalidUrlError(f"Invalid url: {url}") from err
            if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
                raise InvalidUrlError(f"Invalid url: {url}")

        self._lazy_filepath = filepath
        self._lazy_url = url
        self._img = None

    def __getattr__(self, key):
        """Forward unknown attributes to the image, opening it on first use."""
        if key in self._OWN_ATTRS:
            raise AttributeError(key)
        # Open the image once and reuse it; without this check every forwarded
        # attribute access would re-read the file or re-download the URL.
        if self._img is None:
            self._load_img()
        return getattr(self._img, key)

    def _load_img(self):
        """Open the image from the configured file path or URL."""
        if self._lazy_filepath:
            self._img = Image.open(self._lazy_filepath)
            logger.info(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
            )
        elif self._lazy_url:
            self._img = Image.open(
                requests.get(self._lazy_url, stream=True, timeout=60).raw
            )
            logger.info(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
            )

    def __str__(self):
        """Return the file path or URL this image is loaded from."""
        return self._lazy_filepath or self._lazy_url
class WeightedPrompt:
    """A piece of prompt text paired with a numeric weight."""

    def __init__(self, text, weight=1):
        """Store the prompt ``text`` and its ``weight`` (default 1)."""
        self.text = text
        self.weight = weight

    def __repr__(self):
        return f"{type(self).__name__}(text={self.text!r}, weight={self.weight!r})"

    def __str__(self):
        """Render as ``<weight>*(<text>)``."""
        return f"{self.weight}*({self.text})"
class ImaginePrompt:
    """Settings describing one image-generation request."""

    def __init__(
        self,
        prompt=None,
        prompt_strength=7.5,
        init_image=None,  # Pillow Image, LazyLoadingImage, or filepath str
        init_image_strength=0.3,
        seed=None,
        steps=50,
        height=512,
        width=512,
        upscale=False,
        fix_faces=False,
        sampler_type="PLMS",
    ):
        """Normalise the prompt and store the remaining settings as given.

        Args:
            prompt: A string, or an iterable of objects exposing ``.text`` and
                ``.weight`` (e.g. ``WeightedPrompt``). ``None`` falls back to
                ``"a scenic landscape"``.
            init_image: Pillow Image, LazyLoadingImage, or a file path string;
                a string is wrapped in ``LazyLoadingImage(filepath=...)``.
            seed: Used as-is when given; when ``None`` a random integer in
                ``[1, 1_000_000_000]`` is chosen.

        All other arguments are stored unchanged on attributes of the same
        name.
        """
        prompt = prompt if prompt is not None else "a scenic landscape"
        if isinstance(prompt, str):
            self.prompts = [WeightedPrompt(prompt, 1)]
        else:
            # Heaviest prompt first. ``sorted`` builds a new list, so the
            # caller's sequence is left untouched.
            self.prompts = sorted(prompt, key=lambda p: p.weight, reverse=True)
        self.prompt_strength = prompt_strength
        if isinstance(init_image, str):
            init_image = LazyLoadingImage(filepath=init_image)
        self.init_image = init_image
        self.init_image_strength = init_image_strength
        self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
        self.steps = steps
        self.height = height
        self.width = width
        self.upscale = upscale
        self.fix_faces = fix_faces
        self.sampler_type = sampler_type

    @property
    def prompt_text(self):
        """The prompt as one string.

        A single prompt yields its bare text; several prompts are rendered
        with ``str()`` and joined with ``|``.
        """
        if len(self.prompts) == 1:
            return self.prompts[0].text
        return "|".join(str(p) for p in self.prompts)

    def prompt_description(self):
        """Return a one-line human-readable summary of the request."""
        return (
            f'🖼 : "{self.prompt_text}" {self.width}x{self.height}px '
            f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type}"
        )

    def as_dict(self):
        """Return the settings as a plain dict.

        Prompts become ``(weight, text)`` tuples. ``init_image`` is passed
        through ``str()``, so a missing image is reported as ``"None"``.
        """
        prompts = [(p.weight, p.text) for p in self.prompts]
        return {
            "software": "imaginairy",
            "prompts": prompts,
            "prompt_strength": self.prompt_strength,
            "init_image": str(self.init_image),
            "init_image_strength": self.init_image_strength,
            "seed": self.seed,
            "steps": self.steps,
            "height": self.height,
            "width": self.width,
            "upscale": self.upscale,
            "fix_faces": self.fix_faces,
            "sampler_type": self.sampler_type,
        }
class ExifCodes:
    """Numeric tag ids used when writing image metadata.

    https://www.awaresystems.be/imaging/tiff/tifftags/baseline.html
    """

    ImageDescription = 0x010E
    Software = 0x0131
    DateTime = 0x0132
    HostComputer = 0x013C
    UserComment = 0x9286
class ImagineResult:
    """A generated image together with the prompt and environment that made it."""

    def __init__(self, img, prompt: "ImaginePrompt", is_nsfw, upscaled_img=None):
        """Record the image(s), the prompt, and when/where they were produced.

        Args:
            img: The generated image; the other methods call ``tobytes()``
                and ``save()`` on it and convert it with ``numpy.array``.
            prompt: The ``ImaginePrompt`` that produced ``img``.
            is_nsfw: Stored unchanged on ``self.is_nsfw``.
            upscaled_img: Optional second image, written by ``save_upscaled``.
        """
        self.img = img
        self.upscaled_img = upscaled_img
        self.prompt = prompt
        self.is_nsfw = is_nsfw
        # Timezone-aware "now" in UTC (``datetime.utcnow()`` is deprecated).
        self.created_at = datetime.now(timezone.utc)
        device = get_device()
        self.torch_backend = device
        self.hardware_name = get_device_name(device)

    def cv2_img(self):
        """Return the image as a numpy array with the last axis reversed."""
        # Convert RGB to BGR
        # (same result as cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR)
        # for an RGB image, per the original note here)
        return numpy.array(self.img)[:, :, ::-1].copy()

    def md5(self):
        """Return the hex MD5 digest of the image's raw bytes."""
        return hashlib.md5(self.img.tobytes()).hexdigest()

    def metadata_dict(self):
        """Return the metadata that is embedded as JSON in the EXIF data."""
        return {
            "prompt": self.prompt.as_dict(),
        }

    def _exif_datetime(self):
        """Format ``created_at`` as ``YYYY:MM:DD HH:MM:SS``.

        The TIFF/EXIF DateTime tag uses colons in the date part, not the
        dashes that ``isoformat`` produces.
        """
        return self.created_at.strftime("%Y:%m:%d %H:%M:%S")

    def _exif(self):
        """Build the EXIF block written alongside saved images."""
        exif = Image.Exif()
        exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
        exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
        # help future web scrapes not ingest AI generated art
        exif[ExifCodes.Software] = "Imaginairy / Stable Diffusion v1.4"
        exif[ExifCodes.DateTime] = self._exif_datetime()
        exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
        return exif

    def save(self, save_path):
        """Write ``img`` to ``save_path`` with the EXIF metadata attached."""
        self.img.save(save_path, exif=self._exif())

    def save_upscaled(self, save_path):
        """Write ``upscaled_img`` to ``save_path`` with the EXIF metadata attached."""
        self.upscaled_img.save(save_path, exif=self._exif())