feature/fix: migrate to pydantic 2.3

- test: add schema tests/fuzzer and fixes
 - fix default prompt. add tests
 - fix outpaint and controlnet defaults
 - fix init image strength defaults
pull/382/head
Bryce 9 months ago committed by Bryce Drennan
parent 8e956f5360
commit 7c2004bfcc

@ -30,15 +30,15 @@ def generate_image_morph_video():
transcendence_weight = (max(year - 2050, 0) / 3) * robotic_weight
subprompts = [
WeightedPrompt(
f"{year_txt} professional {color_txt} headshot photo of a woman with a pearl earring wearing an {year_txt} outfit. {scene}",
text=f"{year_txt} professional {color_txt} headshot photo of a woman with a pearl earring wearing an {year_txt} outfit. {scene}",
weight=pearl_weight + 0.1,
),
WeightedPrompt(
"photo of a cybernetic woman computer chips in her head. circuits, cybernetic, robotic, biomechanical, elegant, sharp focus, highly detailed, intricate details. scenic majestic mountains of mars in the background",
text="photo of a cybernetic woman computer chips in her head. circuits, cybernetic, robotic, biomechanical, elegant, sharp focus, highly detailed, intricate details. scenic majestic mountains of mars in the background",
weight=robotic_weight + 0.01,
),
WeightedPrompt(
"photo of a cybernetic woman floating above a wormhole. computer chips in her head. circuits, cybernetic, robotic, biomechanical, elegant, sharp focus, highly detailed, intricate details",
text="photo of a cybernetic woman floating above a wormhole. computer chips in her head. circuits, cybernetic, robotic, biomechanical, elegant, sharp focus, highly detailed, intricate details",
weight=transcendence_weight,
),
]

@ -184,3 +184,7 @@ def imagine_cmd(
caption_text,
control_inputs=control_inputs,
)
if __name__ == "__main__":
imagine_cmd() # noqa

@ -245,11 +245,13 @@ CONTROLNET_CONFIGS = [
),
]
CONTROLNET_CONFIG_SHORTCUTS = {m.short_name: m for m in CONTROLNET_CONFIGS}
CONTROLNET_CONFIG_SHORTCUTS = {}
for m in CONTROLNET_CONFIGS:
if m.alias:
CONTROLNET_CONFIG_SHORTCUTS[m.alias] = m
for m in CONTROLNET_CONFIGS:
CONTROLNET_CONFIG_SHORTCUTS[m.short_name] = m
SAMPLER_TYPE_OPTIONS = [
"plms",

@ -206,6 +206,8 @@ def prepare_image_for_outpaint(
def outpaint_arg_str_parse(arg_str):
if not arg_str:
return {}
arg_pattern = re.compile(r"([A-Z]+)(\d+)")
args = arg_str.upper().split(",")

@ -7,9 +7,18 @@ import logging
import os.path
import random
from datetime import datetime, timezone
from io import BytesIO
from typing import TYPE_CHECKING, Any, List, Literal, Optional
from pydantic import BaseModel, Field, validator
from pydantic import (
BaseModel,
Field,
GetCoreSchemaHandler,
field_validator,
model_validator,
)
from pydantic_core import core_schema
from pydantic_core.core_schema import FieldValidationInfo
from imaginairy import config
@ -22,6 +31,20 @@ else:
logger = logging.getLogger(__name__)
def save_image_as_base64(image: "Image.Image") -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
return base64.b64encode(img_bytes).decode()
def load_image_from_base64(image_str: str) -> "Image.Image":
from PIL import Image
img_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(img_bytes))
class InvalidUrlError(ValueError):
pass
@ -29,11 +52,13 @@ class InvalidUrlError(ValueError):
class LazyLoadingImage:
"""Image file encoded as base64 string."""
def __init__(self, *, filepath=None, url=None, img=None):
if not filepath and not url and not img:
raise ValueError("You must specify a url or filepath or img")
if sum([bool(filepath), bool(url), bool(img)]) > 1:
raise ValueError("You cannot specify a url and filepath")
def __init__(self, *, filepath=None, url=None, img: Image = None, b64: str = None):
if not filepath and not url and not img and not b64:
raise ValueError(
"You must specify a url or filepath or img or base64 string"
)
if sum([bool(filepath), bool(url), bool(img), bool(b64)]) > 1:
raise ValueError("You cannot multiple input methods")
# validate file exists
if filepath and not os.path.exists(filepath):
@ -51,6 +76,9 @@ class LazyLoadingImage:
if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
raise InvalidUrlError(f"Invalid url: {url}")
if b64:
img = self.load_image_from_base64(b64)
self._lazy_filepath = filepath
self._lazy_url = url
self._img = img
@ -75,8 +103,11 @@ class LazyLoadingImage:
import requests
self._img = Image.open(
requests.get(self._lazy_url, stream=True, timeout=60).raw
BytesIO(
requests.get(self._lazy_url, stream=True, timeout=60).content
)
)
logger.debug(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
)
@ -86,25 +117,53 @@ class LazyLoadingImage:
self._img = ImageOps.exif_transpose(self._img)
@classmethod
def __modify_schema__(cls, field_schema, field):
field_schema["title"] = field.name.replace("_", " ").title()
@classmethod
def __get_validators__(cls):
yield cls.validate
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
def validate(value: Any) -> "LazyLoadingImage":
from PIL import Image, UnidentifiedImageError
if isinstance(value, cls):
return value
if isinstance(value, Image.Image):
return cls(img=value)
if isinstance(value, str):
if "." in value[:1000]:
try:
return cls(filepath=value)
except FileNotFoundError as e:
raise ValueError(str(e)) # noqa
try:
return cls(b64=value)
except UnidentifiedImageError:
msg = "base64 string was not recognized as a valid image type"
raise ValueError(msg) # noqa
if isinstance(value, dict):
return cls(**value)
raise ValueError(
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
)
@classmethod
def validate(cls, v):
from PIL import Image
def handle_b64(value: Any) -> "LazyLoadingImage":
if isinstance(value, str):
return cls(b64=value)
raise ValueError(
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
)
if isinstance(v, cls):
return v
if isinstance(v, Image.Image):
return cls(img=v)
if isinstance(v, str):
return cls(img=cls.load_image_from_base64(v))
raise ValueError(
"Image value must be either a PIL.Image.Image or a Base64 string"
return core_schema.json_or_python_schema(
json_schema=core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_before_validator_function(
handle_b64, core_schema.any_schema()
),
]
),
python_schema=core_schema.no_info_before_validator_function(
validate, core_schema.any_schema()
),
serialization=core_schema.plain_serializer_function_ser_schema(str),
)
@staticmethod
@ -121,10 +180,13 @@ class LazyLoadingImage:
img_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(img_bytes))
def __str__(self):
def as_base64(self):
self._load_img()
return self.save_image_as_base64(self._img) # type: ignore
def __str__(self):
return self.as_base64()
def __repr__(self):
"""human readable representation.
@ -133,15 +195,31 @@ class LazyLoadingImage:
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
#
# LazyLoadingImage = Annotated[
# _LazyLoadingImage,
# AfterValidator(_LazyLoadingImage.validate),
# PlainSerializer(lambda i: str(i), return_type=str),
# WithJsonSchema({"type": "string"}, mode="serialization"),
# ]
class ControlNetInput(BaseModel):
mode: str
image: Optional[LazyLoadingImage] = None
image_raw: Optional[LazyLoadingImage] = None
strength: int = Field(1, ge=0)
@validator("image_raw")
def image_raw_validate(cls, v, values):
if values.get("image") is not None and v is not None:
strength: int = Field(1, ge=0, le=1000)
# @field_validator("image", "image_raw", mode="before")
# def validate_images(cls, v):
# if isinstance(v, str):
# return LazyLoadingImage(filepath=v)
#
# return v
@field_validator("image_raw")
def image_raw_validate(cls, v, info: FieldValidationInfo):
if info.data.get("image") is not None and v is not None:
raise ValueError("You cannot specify both image and image_raw")
# if v is None and values.get("image") is None:
@ -149,47 +227,68 @@ class ControlNetInput(BaseModel):
return v
@field_validator("mode")
def mode_validate(cls, v):
if v not in config.CONTROLNET_CONFIG_SHORTCUTS:
valid_modes = list(config.CONTROLNET_CONFIG_SHORTCUTS.keys())
valid_modes = ", ".join(valid_modes)
msg = f"Invalid controlnet mode: '{v}'. Valid modes are: {valid_modes}"
raise ValueError(msg)
return v
class WeightedPrompt(BaseModel):
text: str
weight: int = Field(1, ge=0)
weight: float = Field(1, ge=0)
def __repr__(self):
return f"{self.weight}*({self.text})"
class ImaginePrompt(BaseModel):
prompt: Optional[List[WeightedPrompt]]
negative_prompt: Optional[List[WeightedPrompt]]
prompt_strength: Optional[float] = 7.5
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
negative_prompt: Optional[List[WeightedPrompt]] = Field(
default=None, validate_default=True
)
prompt_strength: Optional[float] = Field(
default=7.5, le=10_000, ge=-10_000, validate_default=True
)
init_image: Optional[LazyLoadingImage] = Field(
None, description="base64 encoded image"
None, description="base64 encoded image", validate_default=True
)
init_image_strength: Optional[float] = Field(
ge=0, le=1, default=None, validate_default=True
)
control_inputs: List[ControlNetInput] = Field(
default_factory=list, validate_default=True
)
init_image_strength: Optional[float] = Field(ge=0, le=1)
control_inputs: Optional[List[ControlNetInput]]
mask_prompt: Optional[str] = Field(
description="text description of the things to be masked"
default=None,
description="text description of the things to be masked",
validate_default=True,
)
mask_image: Optional[LazyLoadingImage]
mask_image: Optional[LazyLoadingImage] = Field(default=None, validate_default=True)
mask_mode: Optional[Literal["keep", "replace"]] = "replace"
mask_modify_original: bool = True
outpaint: Optional[str]
model: str = config.DEFAULT_MODEL
model_config_path: Optional[str]
sampler_type: str = config.DEFAULT_SAMPLER
seed: Optional[int]
steps: Optional[int]
height: Optional[int] = Field(None, ge=1)
width: Optional[int] = Field(None, ge=1)
outpaint: Optional[str] = ""
model: str = Field(default=config.DEFAULT_MODEL, validate_default=True)
model_config_path: Optional[str] = None
sampler_type: str = Field(default=config.DEFAULT_SAMPLER, validate_default=True)
seed: Optional[int] = Field(default=None, validate_default=True)
steps: Optional[int] = Field(default=None, validate_default=True)
height: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True)
width: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True)
upscale: bool = False
fix_faces: bool = False
fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1)
fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1, validate_default=True)
conditioning: Optional[str] = None
tile_mode: str = ""
allow_compose_phase: bool = True
is_intermediate: bool = False
collect_progress_latents: bool = False
caption_text: str = Field("", description="text to be overlaid on the image")
caption_text: str = Field(
"", description="text to be overlaid on the image", validate_default=True
)
class MaskMode:
REPLACE = "replace"
@ -199,108 +298,150 @@ class ImaginePrompt(BaseModel):
# allows `prompt` to be positional
super().__init__(prompt=prompt, **kwargs)
@validator("prompt", "negative_prompt", pre=True, always=True)
@field_validator("prompt", "negative_prompt", mode="before")
@classmethod
def make_into_weighted_prompts(cls, v):
# if isinstance(v, list):
# v = [WeightedPrompt.parse_obj(p) if isinstance(p, dict) else p for p in v]
if isinstance(v, str):
v = [WeightedPrompt(text=v)]
elif isinstance(v, WeightedPrompt):
v = [v]
return v
@validator("prompt", "negative_prompt", always=True)
@field_validator("prompt", "negative_prompt", mode="after")
@classmethod
def must_have_some_weight(cls, v):
if v:
total_weight = sum(p.weight for p in v)
if total_weight == 0:
raise ValueError("Total weight of prompts cannot be 0")
return v
@field_validator("prompt", "negative_prompt", mode="after")
def sort_prompts(cls, v):
if isinstance(v, list):
v.sort(key=lambda p: p.weight, reverse=True)
return v
@validator("negative_prompt", always=True)
def validate_negative_prompt(cls, v, values):
if not v:
model_config = config.MODEL_CONFIG_SHORTCUTS.get(v, None)
@model_validator(mode="after")
def validate_negative_prompt(self):
if self.negative_prompt is None:
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None)
if model_config:
v = [WeightedPrompt(text=model_config.default_negative_prompt)]
self.negative_prompt = [
WeightedPrompt(text=model_config.default_negative_prompt)
]
else:
v = [WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)]
self.negative_prompt = [
WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)
]
return self
return v
@validator("prompt_strength", always=True)
@field_validator("prompt_strength")
def validate_prompt_strength(cls, v):
return 7.5 if v is None else v
@validator("tile_mode", always=True, pre=True)
@field_validator("tile_mode", mode="before")
def validate_tile_mode(cls, v):
valid_tile_modes = ("", "x", "y", "xy")
if v is True:
return "xy"
if v is False:
if v is False or v is None:
return ""
if not isinstance(v, str):
raise ValueError(
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
)
v = v.lower()
assert v in ("", "x", "y", "xy")
if v not in valid_tile_modes:
raise ValueError(
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
)
return v
@validator("init_image", "mask_image", always=True)
def handle_images(cls, v):
if isinstance(v, str):
return LazyLoadingImage(filepath=v)
@field_validator("outpaint", mode="after")
def validate_outpaint(cls, v):
from imaginairy.outpaint import outpaint_arg_str_parse
outpaint_arg_str_parse(v)
return v
@validator("init_image", always=True)
def set_init_from_control_inputs(cls, v, values):
if v is None and values.get("control_inputs"):
for control_input in values["control_inputs"]:
@field_validator("conditioning", mode="after")
def validate_conditioning(cls, v):
from torch import Tensor
if v is None:
return v
if not isinstance(v, Tensor):
raise ValueError("conditioning must be a torch.Tensor")
return v
# @field_validator("init_image", "mask_image", mode="after")
# def handle_images(cls, v):
# if isinstance(v, str):
# return LazyLoadingImage(filepath=v)
#
# return v
@model_validator(mode="after")
def set_init_from_control_inputs(self):
if self.init_image is None:
for control_input in self.control_inputs:
if control_input.image:
return control_input.image
self.init_image = control_input.image
break
return self
@field_validator("control_inputs", mode="before")
def validate_control_inputs(cls, v):
if v is None:
v = []
return v
@validator("control_inputs", always=True)
def set_image_from_init_image(cls, v, values):
@field_validator("control_inputs", mode="after")
def set_image_from_init_image(cls, v, info: FieldValidationInfo):
v = v or []
for control_input in v:
print(control_input)
if control_input.image is None and control_input.image_raw is None:
control_input.image = values["init_image"]
control_input.image = info.data["init_image"]
return v
@validator("mask_image", always=True)
def validate_mask_image(cls, v, values):
if v is not None and values["mask_prompt"] is not None:
@field_validator("mask_image")
def validate_mask_image(cls, v, info: FieldValidationInfo):
if v is not None and info.data.get("mask_prompt") is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
return v
@validator("mask_prompt", always=True)
def validate_mask_prompt(cls, v, values):
if values["init_image"] is None and v:
raise ValueError(
"You must set `init_image` if you want to use `mask_prompt`"
)
@field_validator("mask_prompt", "mask_image", mode="before")
def validate_mask_prompt(cls, v, info: FieldValidationInfo):
if info.data.get("init_image") is None and v:
raise ValueError("You must set `init_image` if you want to use a mask")
return v
@validator("model", always=True)
@field_validator("model", mode="before")
def set_default_diffusion_model(cls, v):
if v is None:
return config.DEFAULT_MODEL
return v
@validator("seed", always=True)
@field_validator("seed")
def validate_seed(cls, v):
return v
@validator("fix_faces_fidelity", always=True)
@field_validator("fix_faces_fidelity", mode="before")
def validate_fix_faces_fidelity(cls, v):
if v is None:
return 0.2
return v
@validator("sampler_type", pre=True, always=True)
def validate_sampler_type(cls, v, values):
@field_validator("sampler_type", mode="after")
def validate_sampler_type(cls, v, info: FieldValidationInfo):
from imaginairy.samplers import SamplerName
if v is None:
@ -308,10 +449,10 @@ class ImaginePrompt(BaseModel):
v = v.lower()
if values["model"] == "SD-2.0-v" and v == SamplerName.PLMS:
if info.data.get("model") == "SD-2.0-v" and v == SamplerName.PLMS:
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
if values["model"] == "edit" and v in (
if info.data.get("model") == "edit" and v in (
SamplerName.PLMS,
SamplerName.DDIM,
):
@ -320,43 +461,39 @@ class ImaginePrompt(BaseModel):
)
return v
@validator("steps", always=True)
def validate_steps(cls, v, values):
@field_validator("steps")
def validate_steps(cls, v, info: FieldValidationInfo):
from imaginairy.samplers import SAMPLER_LOOKUP
if v is None:
SamplerCls = SAMPLER_LOOKUP[values["sampler_type"]]
SamplerCls = SAMPLER_LOOKUP[info.data["sampler_type"]]
v = SamplerCls.default_steps
return int(v)
@validator("init_image_strength", always=True)
def validate_init_image_strength(cls, v, values):
if v is None:
if values.get("control_inputs"):
v = 0.0
elif (
values.get("outpaint")
or values.get("mask_image")
or values.get("mask_prompt")
):
v = 0.0
@model_validator(mode="after")
def validate_init_image_strength(self):
if self.init_image_strength is None:
if self.control_inputs:
self.init_image_strength = 0.0
elif self.outpaint or self.mask_image or self.mask_prompt:
self.init_image_strength = 0.0
else:
v = 0.2
self.init_image_strength = 0.2
return v
return self
@validator("height", "width", always=True)
def validate_image_size(cls, v, values):
@field_validator("height", "width")
def validate_image_size(cls, v, info: FieldValidationInfo):
from imaginairy.model_manager import get_model_default_image_size
if v is None:
v = get_model_default_image_size(values["model"])
v = get_model_default_image_size(info.data["model"])
return v
@validator("caption_text", pre=True, always=True)
def validate_caption_text(cls, v, values):
@field_validator("caption_text", mode="before")
def validate_caption_text(cls, v):
if v is None:
v = ""
@ -391,7 +528,7 @@ class ImaginePrompt(BaseModel):
def logging_dict(self):
"""Return a dict of the object but with binary data replaced with reprs."""
data = self.dict()
data = self.model_dump()
data["init_image"] = repr(self.init_image)
data["mask_image"] = repr(self.mask_image)
if self.control_inputs:
@ -399,17 +536,12 @@ class ImaginePrompt(BaseModel):
return data
def full_copy(self, deep=True, update=None):
new_prompt = self.copy(
new_prompt = self.model_copy(
deep=deep,
update=update,
)
new_prompt = new_prompt.validate(
dict(
new_prompt._iter( # noqa
to_dict=False, by_alias=False, exclude_unset=True
)
)
)
# new_prompt = self.model_validate(new_prompt) doesn't work for some reason https://github.com/pydantic/pydantic/issues/7387
new_prompt = new_prompt.model_validate(dict(new_prompt))
return new_prompt
def make_concrete_copy(self):

@ -0,0 +1,177 @@
import math
import sys
from copy import deepcopy
from decimal import Decimal
from typing import Dict, Tuple, Union
NODE_DELETE = object()
DISTORTED_NUMBERS = [
math.nan,
math.inf,
-math.inf,
0,
1,
-1,
0.000000000001,
-0.000000000001,
2**1024,
-(2**1024),
Decimal("0.000000000001"),
Decimal(1) / Decimal(3),
1.0 / 3.0,
sys.float_info.max,
"20",
1.3333333333333333333e20,
]
DISTORTED_DATES = [
"2021-01-01T00:00:00",
"2021-01-01",
"0000-00-00",
"0001-01-01",
"9001-01-01",
]
DISTORTED_STRINGS = [
"",
b"\00\001\002\003\004\005\006\007\010\011\012\013\014\015\016\017",
" ",
"\t\r\n",
"\\r\\n\\t",
"hello",
"👩‍👩‍👧‍👧👩‍👩‍👧‍👧👩‍👩‍👧‍👧👩‍👩‍👧‍👧👩‍👩‍👧‍👧",
"a" * 10000,
"0",
"!@#$%^&*()_+-=[]{}|;':\",.<>?/©™®",
"你好こんにちは안녕하세요Привет",
"<script>alert('Hello')</script>",
(
"àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ"
"ĀāĂ㥹ĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚěĜĝĞğĠġĢģĤĥĦħ"
"ĨĩĪīĬĭĮįİıIJijĴĵĶķĸĹĺĻļĽľĿŀŁłŃńŅņŇňʼnŊŋ"
"ŌōŎŏŐőŒœŔŕŖŗŘřŚśŜŝŞşŠšŢţŤťŦŧŨũŪūŬŭŮůŰűŲų"
"ŴŵŶŷŸŹźŻżŽžſ"
),
]
DISTORTED_BOOLEAN = [
True,
False,
"True",
"False",
]
DISTORTED_OTHER = [(), object(), type(object), lambda x: x, NODE_DELETE, None]
DISTORTED_VALUES = (
DISTORTED_NUMBERS
+ DISTORTED_DATES
+ DISTORTED_STRINGS
+ DISTORTED_BOOLEAN
+ DISTORTED_OTHER
)
class DataDistorter:
def __init__(self, data, add_data_values=True):
self.data = deepcopy(data)
self.data_map, self.data_unique_values = create_node_map(self.data)
self.distortion_values = DISTORTED_VALUES + []
if add_data_values:
self.distortion_values += list(self.data_unique_values)
def make_distorted_copy(self, node_number: int, distorted_value):
"""
Make a distorted copy of the data.
The node number is the index in the node map.
"""
data = deepcopy(self.data)
data = replace_value_at_path(data, self.data_map[node_number], distorted_value)
return data
def single_distortions(self):
for node_number in range(len(self.data_map)):
for distorted_value in DISTORTED_VALUES:
yield self.make_distorted_copy(node_number, distorted_value)
def double_distortions(self):
for node_number in range(len(self.data_map)):
for distorted_value in DISTORTED_VALUES:
self.make_distorted_copy(node_number, distorted_value)
def __iter__(self):
for node_number in range(len(self.data_map)):
for distorted_value in DISTORTED_VALUES:
yield self.make_distorted_copy(node_number, distorted_value)
# nested dictionary helper functions
def create_node_map(data: Union[dict, list, tuple]) -> Tuple[Dict[int, list], set]:
"""
Create a map of node numbers to paths in a nested dictionary.
Include all nodes, not just leaves.
Example:
data = {"a": {"b": ["c", "d"]}, "e": "f"}
node_map = create_node_map(data)
assert node_map = {
0: [],
1: ["a"],
2: ["a", "b"],
3: ["a", "b", 0],
4: ["a", "b", 1],
5: ["e"],
}
"""
node_map = {}
node_values = set()
node_num = [
0
] # Using a list to hold the current node number as integers are immutable
def _traverse(curr_data, curr_path):
node_map[node_num[0]] = curr_path.copy()
node_num[0] += 1
if isinstance(curr_data, dict):
for key, value in curr_data.items():
_traverse(value, curr_path + [key])
elif isinstance(curr_data, (list, tuple)):
for idx, item in enumerate(curr_data):
_traverse(item, curr_path + [idx])
else:
try:
node_values.add(curr_data)
except TypeError:
pass
_traverse(data, [])
return node_map, node_values
def get_path(data: dict, path):
"""Get a value from a nested dictionary using a path."""
curr_data = data
for key in path:
curr_data = curr_data[key]
return curr_data
def replace_value_at_path(data, path, new_value):
"""Replace a value in a nested dictionary using a path."""
if not path:
return new_value
parent = get_path(data, path[:-1])
last_key = path[-1]
if new_value == NODE_DELETE:
del parent[last_key]
else:
parent[last_key] = new_value
return data

@ -8,6 +8,8 @@ aiohttp==3.8.5
# via fsspec
aiosignal==1.3.1
# via aiohttp
annotated-types==0.5.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via omegaconf
anyio==3.7.1
@ -20,7 +22,7 @@ async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
black==23.9.0
black==23.9.1
# via -r requirements-dev.in
certifi==2023.7.22
# via requests
@ -85,7 +87,7 @@ ftfy==6.1.1
# open-clip-torch
h11==0.14.0
# via uvicorn
huggingface-hub==0.16.4
huggingface-hub==0.17.1
# via
# diffusers
# open-clip-torch
@ -120,7 +122,7 @@ lightning-utilities==0.9.0
# torchmetrics
llvmlite==0.40.1
# via numba
matplotlib==3.7.2
matplotlib==3.7.3
# via filterpy
mccabe==0.7.0
# via
@ -200,10 +202,12 @@ pycln==2.2.2
# via -r requirements-dev.in
pycodestyle==2.11.0
# via pylama
pydantic==1.10.12
pydantic==2.3.0
# via
# fastapi
# imaginAIry (setup.py)
pydantic-core==2.6.3
# via pydantic
pydocstyle==6.3.0
# via pylama
pyflakes==3.1.0
@ -212,7 +216,7 @@ pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.5
# via -r requirements-dev.in
pyparsing==3.0.9
pyparsing==3.1.1
# via matplotlib
pytest==7.4.2
# via
@ -253,7 +257,7 @@ requests==2.31.0
# transformers
responses==0.23.3
# via -r requirements-dev.in
ruff==0.0.287
ruff==0.0.288
# via -r requirements-dev.in
safetensors==0.3.3
# via
@ -308,7 +312,7 @@ torch==1.13.1
# torchvision
torchdiffeq==0.2.3
# via imaginAIry (setup.py)
torchmetrics==1.1.1
torchmetrics==1.1.2
# via
# imaginAIry (setup.py)
# pytorch-lightning
@ -342,6 +346,7 @@ typing-extensions==4.7.1
# libcst
# lightning-utilities
# pydantic
# pydantic-core
# pytorch-lightning
# torch
# torchvision

@ -95,7 +95,7 @@ setup(
"open-clip-torch>=2.0.0",
"opencv-python>=4.4.0.46",
# need to migration to 2.0
"pydantic<2.0.0",
"pydantic>=2.3.0",
"requests>=2.28.1",
"einops>=0.3.0",
"safetensors>=0.2.1",

@ -1,35 +0,0 @@
from typing import Optional
import pytest
from pydantic import BaseModel
from imaginairy import LazyLoadingImage
from imaginairy.schema import InvalidUrlError
from tests import TESTS_FOLDER
def test_lazy_load_image():
with pytest.raises(ValueError, match=r".*specify a url or filepath.*"):
LazyLoadingImage()
with pytest.raises(FileNotFoundError, match=r".*File does not exist.*"):
LazyLoadingImage(filepath="/tmp/bterpojirewpdfsn/ergqgr")
with pytest.raises(InvalidUrlError):
LazyLoadingImage(url="/tmp/bterpojirewpdfsn/ergqgr")
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
assert img.size == (1686, 1246)
def test_image_serialization():
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
orig_size = img.size
img64 = str(img)
class TestModel(BaseModel):
img: Optional[LazyLoadingImage]
m = TestModel.parse_raw(f'{{"img": "{img64}"}}')
assert m.img.size == orig_size
assert str(m.img) == img64

@ -0,0 +1,39 @@
import pytest
from pydantic import ValidationError
from imaginairy import LazyLoadingImage
from imaginairy.schema import ControlNetInput
from tests import TESTS_FOLDER
@pytest.fixture(name="lazy_img")
def _lazy_img():
return LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
def test_controlnetinput_basic(lazy_img):
ControlNetInput(mode="canny", image=lazy_img)
ControlNetInput(mode="canny", image_raw=lazy_img)
def test_controlnetinput_invalid_mode(lazy_img):
with pytest.raises(ValueError, match=r".*Invalid controlnet mode.*"):
ControlNetInput(mode="pizza", image=lazy_img)
def test_controlnetinput_both_images(lazy_img):
with pytest.raises(ValueError, match=r".*cannot specify both.*"):
ControlNetInput(mode="canny", image=lazy_img, image_raw=lazy_img)
def test_controlnetinput_filepath_input(lazy_img):
"""Test that we accept filepaths here."""
c = ControlNetInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png")
c.image.convert("RGB")
c = ControlNetInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png")
c.image_raw.convert("RGB")
def test_controlnetinput_big(lazy_img):
with pytest.raises(ValidationError, match=r".*less than or.*"):
ControlNetInput(mode="canny", strength=2**2048)

@ -0,0 +1,229 @@
import pytest
from pydantic import ValidationError
from imaginairy import LazyLoadingImage, config
from imaginairy.schema import ControlNetInput, ImaginePrompt, WeightedPrompt
from imaginairy.utils.data_distorter import DataDistorter
from tests import TESTS_FOLDER
def test_imagine_prompt_has_default_negative():
prompt = ImaginePrompt("fruit salad", model="foobar")
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
def test_imagine_prompt_custom_negative_prompt():
prompt = ImaginePrompt("fruit salad", negative_prompt="pizza")
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
assert prompt.negative_prompt[0].text == "pizza"
def test_imagine_prompt_model_specific_negative_prompt():
prompt = ImaginePrompt("fruit salad", model="openjourney-v1")
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
assert prompt.negative_prompt[0].text == ""
def test_imagine_prompt_weighted_prompts():
prompt = ImaginePrompt(WeightedPrompt(text="cat", weight=0.1))
assert isinstance(prompt.prompt[0], WeightedPrompt)
prompt = ImaginePrompt(
[
WeightedPrompt(text="cat", weight=0.1),
WeightedPrompt(text="dog", weight=0.2),
]
)
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert prompt.prompt[0].text == "dog"
def test_imagine_prompt_tile_mode():
prompt = ImaginePrompt("fruit")
assert prompt.tile_mode == ""
prompt = ImaginePrompt("fruit", tile_mode=True)
assert prompt.tile_mode == "xy"
prompt = ImaginePrompt("fruit", tile_mode=False)
assert prompt.tile_mode == ""
prompt = ImaginePrompt("fruit", tile_mode="X")
assert prompt.tile_mode == "x"
with pytest.raises(ValueError, match=r".*Invalid tile_mode.*"):
ImaginePrompt("fruit", tile_mode="pizza")
def test_imagine_prompt_copy():
p1 = ImaginePrompt("fruit")
p2 = p1.full_copy()
assert p1 == p2
assert id(p1) != id(p2)
def test_imagine_prompt_concrete_copy():
p1 = ImaginePrompt("fruit")
p2 = p1.make_concrete_copy()
assert p1 != p2
assert id(p1) != id(p2)
assert p1.seed is None
assert p2.seed is not None
def test_imagine_prompt_image_paths():
p = ImaginePrompt("fruit", init_image=f"{TESTS_FOLDER}/data/red.png")
assert isinstance(p.init_image, LazyLoadingImage)
def test_imagine_prompt_control_inputs():
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
prompt = ImaginePrompt(
"fruit",
control_inputs=[
ControlNetInput(mode="depth", image=img),
],
)
prompt.control_inputs[0].image.convert("RGB")
# init image should be set from first control-image if init image wasn't set
assert prompt.init_image is not None
assert isinstance(prompt.init_image, LazyLoadingImage)
# if an image isn't specified for a controlnet, use an init image
prompt = ImaginePrompt(
"fruit",
init_image=img,
control_inputs=[
ControlNetInput(mode="depth"),
],
)
assert prompt.control_inputs[0].image is not None
# if an image isn't specified for a controlnet or init image, what should happen?
prompt = ImaginePrompt(
"fruit",
control_inputs=[
ControlNetInput(mode="depth"),
],
)
assert prompt.control_inputs[0].image is None
def test_imagine_prompt_mask_params():
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
with pytest.raises(ValueError, match=r".*only set one.*"):
ImaginePrompt(
"fruit",
init_image=img,
mask_prompt="apple",
mask_image=img,
)
with pytest.raises(ValueError, match=r".*if you want to use a mask.*"):
ImaginePrompt(
"fruit",
mask_prompt="apple",
)
with pytest.raises(ValueError, match=r".*if you want to use a mask.*"):
ImaginePrompt(
"fruit",
mask_image=img,
)
def test_imagine_prompt_default_model():
prompt = ImaginePrompt("fruit", model=None)
assert prompt.model == config.DEFAULT_MODEL
def test_imagine_prompt_default_negative():
prompt = ImaginePrompt("fruit")
assert prompt.negative_prompt[0].text == config.DEFAULT_NEGATIVE_PROMPT
def test_imagine_prompt_fix_faces_fidelity():
assert ImaginePrompt("fruit", fix_faces_fidelity=None).fix_faces_fidelity == 0.2
def test_imagine_prompt_init_strength_zero():
lazy_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
prompt = ImaginePrompt(
"fruit", control_inputs=[ControlNetInput(mode="depth", image=lazy_img)]
)
assert prompt.init_image_strength == 0.0
prompt = ImaginePrompt("fruit")
assert prompt.init_image_strength == 0.2
def test_distorted_prompts():
prompt_obj = ImaginePrompt(
prompt=[
WeightedPrompt(text="sunset", weight=0.7),
WeightedPrompt(text="beach", weight=1.3),
],
negative_prompt=[WeightedPrompt(text="night", weight=1.0)],
prompt_strength=7.0,
init_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
init_image_strength=0.5,
control_inputs=[
ControlNetInput(
mode="details",
image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
strength=2,
),
ControlNetInput(
mode="depth",
image_raw=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
strength=3,
),
],
mask_prompt=None,
mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
mask_mode="replace",
mask_modify_original=False,
outpaint="all5,up0,down20",
model=config.DEFAULT_MODEL,
model_config_path=None,
sampler_type=config.DEFAULT_SAMPLER,
seed=42,
steps=10,
height=256,
width=256,
upscale=True,
fix_faces=True,
fix_faces_fidelity=0.7,
conditioning=None,
tile_mode="xy",
allow_compose_phase=False,
is_intermediate=False,
collect_progress_latents=False,
caption_text="Sample Caption",
)
data = prompt_obj.model_dump(mode="python")
valid_prompts = []
total_prompts = 0
for i, distorted_data in enumerate(DataDistorter(data)):
total_prompts += 1
try:
distorted_prompt = ImaginePrompt.model_validate(distorted_data)
valid_prompts.append(distorted_prompt)
except ValidationError:
continue
print(f"Valid prompts: {len(valid_prompts)}")
print(f"Invalid prompts: {total_prompts - len(valid_prompts)}")
# for p in valid_prompts:
# try:
# imagine_image_files(p, f"{TESTS_FOLDER}/test_output/distorted_prompts/")
# except ValueError as e:
# print(f"################{e}")
# continue
# except Exception as e:
# print("################")
# print(p)
# raise e

@ -0,0 +1,103 @@
import os.path
from typing import Optional
import pytest
from PIL import Image
from pydantic import BaseModel
from imaginairy import LazyLoadingImage
from imaginairy.schema import InvalidUrlError
from tests import TESTS_FOLDER
class TestModel(BaseModel):
header_img: Optional[LazyLoadingImage]
@pytest.fixture(name="red_url")
def _red_url(mocked_responses):
url = "http://example.com/red.png"
with open(os.path.join(TESTS_FOLDER, "data", "red.png"), "rb") as f:
img_data = f.read()
mocked_responses.get(
url,
body=img_data,
status=200,
content_type="image/png",
)
yield url
@pytest.fixture(name="red_path")
def _red_path():
return os.path.join(TESTS_FOLDER, "data", "red.png")
@pytest.fixture(name="red_b64")
def _red_b64():
return "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
def test_lazy_load_image(mocked_responses, red_url, red_path, red_b64):
ll_img = LazyLoadingImage(filepath=red_path)
assert ll_img.size == (512, 512)
assert ll_img.as_base64() == red_b64
ll_img = LazyLoadingImage(url=red_url)
assert ll_img.size == (512, 512)
assert ll_img.as_base64() == red_b64
ll_img = LazyLoadingImage(img=Image.open(red_path))
assert ll_img.size == (512, 512)
assert ll_img.as_base64() == red_b64
ll_img = LazyLoadingImage(b64=red_b64)
assert ll_img.size == (512, 512)
assert ll_img.as_base64() == red_b64
def test_lazy_load_image_validation():
with pytest.raises(ValueError, match=r".*specify a url or filepath.*"):
LazyLoadingImage()
with pytest.raises(FileNotFoundError, match=r".*File does not exist.*"):
LazyLoadingImage(filepath="/tmp/bterpojirewpdfsn/ergqgr")
with pytest.raises(InvalidUrlError):
LazyLoadingImage(url="/tmp/bterpojirewpdfsn/ergqgr")
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
assert img.size == (1686, 1246)
def test_image_dump(red_path, red_b64):
obj = TestModel(header_img=LazyLoadingImage(filepath=red_path))
assert obj.header_img.size == (512, 512)
obj_data = obj.model_dump_json()
new_obj = TestModel.model_validate_json(obj_data)
assert new_obj.header_img.size == (512, 512)
assert new_obj.header_img.as_base64() == red_b64
obj_data = obj.model_dump(mode="json")
new_obj = TestModel.model_validate(obj_data)
assert new_obj.header_img.size == (512, 512)
assert new_obj.header_img.as_base64() == red_b64
obj_data = obj.model_dump(mode="python")
new_obj = TestModel.model_validate(obj_data)
assert new_obj.header_img.size == (512, 512)
assert new_obj.header_img.as_base64() == red_b64
def test_image_deserialization(red_path, red_url):
rows = [
{"header_img": LazyLoadingImage(filepath=red_path)},
{"header_img": red_path},
{"header_img": {"filepath": red_path}},
{"header_img": {"url": red_url}},
]
for row in rows:
obj = TestModel.model_validate(row)
assert obj.header_img.size == (512, 512)

@ -13,7 +13,7 @@ linters = pylint,pycodestyle,pyflakes,mypy
ignore =
Z999,C0103,C0201,C0301,C0302,C0114,C0115,C0116,C0415,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,E1131,E1135,E1136,
Z999,E203,E501,E1101,E1121,E1131,E1133,E1135,E1136,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W0632,W1203

Loading…
Cancel
Save