@ -7,9 +7,18 @@ import logging
import os . path
import random
from datetime import datetime , timezone
from io import BytesIO
from typing import TYPE_CHECKING , Any , List , Literal , Optional
from pydantic import BaseModel , Field , validator
from pydantic import (
BaseModel ,
Field ,
GetCoreSchemaHandler ,
field_validator ,
model_validator ,
)
from pydantic_core import core_schema
from pydantic_core . core_schema import FieldValidationInfo
from imaginairy import config
@ -22,6 +31,20 @@ else:
logger = logging . getLogger ( __name__ )
def save_image_as_base64 ( image : " Image.Image " ) - > str :
buffered = io . BytesIO ( )
image . save ( buffered , format = " PNG " )
img_bytes = buffered . getvalue ( )
return base64 . b64encode ( img_bytes ) . decode ( )
def load_image_from_base64 ( image_str : str ) - > " Image.Image " :
from PIL import Image
img_bytes = base64 . b64decode ( image_str )
return Image . open ( io . BytesIO ( img_bytes ) )
class InvalidUrlError ( ValueError ) :
pass
@ -29,11 +52,13 @@ class InvalidUrlError(ValueError):
class LazyLoadingImage :
""" Image file encoded as base64 string. """
def __init__ ( self , * , filepath = None , url = None , img = None ) :
if not filepath and not url and not img :
raise ValueError ( " You must specify a url or filepath or img " )
if sum ( [ bool ( filepath ) , bool ( url ) , bool ( img ) ] ) > 1 :
raise ValueError ( " You cannot specify a url and filepath " )
def __init__ ( self , * , filepath = None , url = None , img : Image = None , b64 : str = None ) :
if not filepath and not url and not img and not b64 :
raise ValueError (
" You must specify a url or filepath or img or base64 string "
)
if sum ( [ bool ( filepath ) , bool ( url ) , bool ( img ) , bool ( b64 ) ] ) > 1 :
raise ValueError ( " You cannot multiple input methods " )
# validate file exists
if filepath and not os . path . exists ( filepath ) :
@ -51,6 +76,9 @@ class LazyLoadingImage:
if parsed_url . scheme not in { " http " , " https " } or not parsed_url . host :
raise InvalidUrlError ( f " Invalid url: { url } " )
if b64 :
img = self . load_image_from_base64 ( b64 )
self . _lazy_filepath = filepath
self . _lazy_url = url
self . _img = img
@ -75,8 +103,11 @@ class LazyLoadingImage:
import requests
self . _img = Image . open (
requests . get ( self . _lazy_url , stream = True , timeout = 60 ) . raw
BytesIO (
requests . get ( self . _lazy_url , stream = True , timeout = 60 ) . content
)
)
logger . debug (
f " Loaded input 🖼 of size { self . _img . size } from { self . _lazy_url } "
)
@ -86,25 +117,53 @@ class LazyLoadingImage:
self . _img = ImageOps . exif_transpose ( self . _img )
@classmethod
def __modify_schema__ ( cls , field_schema , field ) :
field_schema [ " title " ] = field . name . replace ( " _ " , " " ) . title ( )
@classmethod
def __get_validators__ ( cls ) :
yield cls . validate
def __get_pydantic_core_schema__ (
cls , source_type : Any , handler : GetCoreSchemaHandler
) - > core_schema . CoreSchema :
def validate ( value : Any ) - > " LazyLoadingImage " :
from PIL import Image , UnidentifiedImageError
if isinstance ( value , cls ) :
return value
if isinstance ( value , Image . Image ) :
return cls ( img = value )
if isinstance ( value , str ) :
if " . " in value [ : 1000 ] :
try :
return cls ( filepath = value )
except FileNotFoundError as e :
raise ValueError ( str ( e ) ) # noqa
try :
return cls ( b64 = value )
except UnidentifiedImageError :
msg = " base64 string was not recognized as a valid image type "
raise ValueError ( msg ) # noqa
if isinstance ( value , dict ) :
return cls ( * * value )
raise ValueError (
" Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string "
)
@classmethod
def validate ( cls , v ) :
from PIL import Image
def handle_b64 ( value : Any ) - > " LazyLoadingImage " :
if isinstance ( value , str ) :
return cls ( b64 = value )
raise ValueError (
" Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string "
)
if isinstance ( v , cls ) :
return v
if isinstance ( v , Image . Image ) :
return cls ( img = v )
if isinstance ( v , str ) :
return cls ( img = cls . load_image_from_base64 ( v ) )
raise ValueError (
" Image value must be either a PIL.Image.Image or a Base64 string "
return core_schema . json_or_python_schema (
json_schema = core_schema . chain_schema (
[
core_schema . str_schema ( ) ,
core_schema . no_info_before_validator_function (
handle_b64 , core_schema . any_schema ( )
) ,
]
) ,
python_schema = core_schema . no_info_before_validator_function (
validate , core_schema . any_schema ( )
) ,
serialization = core_schema . plain_serializer_function_ser_schema ( str ) ,
)
@staticmethod
@ -121,10 +180,13 @@ class LazyLoadingImage:
img_bytes = base64 . b64decode ( image_str )
return Image . open ( io . BytesIO ( img_bytes ) )
def __str__ ( self ) :
def as_base64 ( self ) :
self . _load_img ( )
return self . save_image_as_base64 ( self . _img ) # type: ignore
def __str__ ( self ) :
return self . as_base64 ( )
def __repr__ ( self ) :
""" human readable representation.
@ -133,15 +195,31 @@ class LazyLoadingImage:
return f " <LazyLoadingImage filepath= { self . _lazy_filepath } url= { self . _lazy_url } > "
#
# LazyLoadingImage = Annotated[
# _LazyLoadingImage,
# AfterValidator(_LazyLoadingImage.validate),
# PlainSerializer(lambda i: str(i), return_type=str),
# WithJsonSchema({"type": "string"}, mode="serialization"),
# ]
class ControlNetInput ( BaseModel ) :
mode : str
image : Optional [ LazyLoadingImage ] = None
image_raw : Optional [ LazyLoadingImage ] = None
strength : int = Field ( 1 , ge = 0 )
@validator ( " image_raw " )
def image_raw_validate ( cls , v , values ) :
if values . get ( " image " ) is not None and v is not None :
strength : int = Field ( 1 , ge = 0 , le = 1000 )
# @field_validator("image", "image_raw", mode="before")
# def validate_images(cls, v):
# if isinstance(v, str):
# return LazyLoadingImage(filepath=v)
#
# return v
@field_validator ( " image_raw " )
def image_raw_validate ( cls , v , info : FieldValidationInfo ) :
if info . data . get ( " image " ) is not None and v is not None :
raise ValueError ( " You cannot specify both image and image_raw " )
# if v is None and values.get("image") is None:
@ -149,47 +227,68 @@ class ControlNetInput(BaseModel):
return v
@field_validator ( " mode " )
def mode_validate ( cls , v ) :
if v not in config . CONTROLNET_CONFIG_SHORTCUTS :
valid_modes = list ( config . CONTROLNET_CONFIG_SHORTCUTS . keys ( ) )
valid_modes = " , " . join ( valid_modes )
msg = f " Invalid controlnet mode: ' { v } ' . Valid modes are: { valid_modes } "
raise ValueError ( msg )
return v
class WeightedPrompt ( BaseModel ) :
text : str
weight : int = Field ( 1 , ge = 0 )
weight : floa t = Field ( 1 , ge = 0 )
def __repr__ ( self ) :
return f " { self . weight } *( { self . text } ) "
class ImaginePrompt ( BaseModel ) :
prompt : Optional [ List [ WeightedPrompt ] ]
negative_prompt : Optional [ List [ WeightedPrompt ] ]
prompt_strength : Optional [ float ] = 7.5
prompt : Optional [ List [ WeightedPrompt ] ] = Field ( default = None , validate_default = True )
negative_prompt : Optional [ List [ WeightedPrompt ] ] = Field (
default = None , validate_default = True
)
prompt_strength : Optional [ float ] = Field (
default = 7.5 , le = 10_000 , ge = - 10_000 , validate_default = True
)
init_image : Optional [ LazyLoadingImage ] = Field (
None , description = " base64 encoded image "
None , description = " base64 encoded image " , validate_default = True
)
init_image_strength : Optional [ float ] = Field (
ge = 0 , le = 1 , default = None , validate_default = True
)
control_inputs : List [ ControlNetInput ] = Field (
default_factory = list , validate_default = True
)
init_image_strength : Optional [ float ] = Field ( ge = 0 , le = 1 )
control_inputs : Optional [ List [ ControlNetInput ] ]
mask_prompt : Optional [ str ] = Field (
description = " text description of the things to be masked "
default = None ,
description = " text description of the things to be masked " ,
validate_default = True ,
)
mask_image : Optional [ LazyLoadingImage ]
mask_image : Optional [ LazyLoadingImage ] = Field ( default = None , validate_default = True )
mask_mode : Optional [ Literal [ " keep " , " replace " ] ] = " replace "
mask_modify_original : bool = True
outpaint : Optional [ str ]
model : str = config. DEFAULT_MODEL
model_config_path : Optional [ str ]
sampler_type : str = config. DEFAULT_SAMPLER
seed : Optional [ int ]
steps : Optional [ int ]
height : Optional [ int ] = Field ( None , ge = 1 )
width : Optional [ int ] = Field ( None , ge = 1 )
outpaint : Optional [ str ] = " "
model : str = Field( default = config. DEFAULT_MODEL , validate_default = True )
model_config_path : Optional [ str ] = None
sampler_type : str = Field( default = config. DEFAULT_SAMPLER , validate_default = True )
seed : Optional [ int ] = Field ( default = None , validate_default = True )
steps : Optional [ int ] = Field ( default = None , validate_default = True )
height : Optional [ int ] = Field ( None , ge = 1 , le = 100_000 , validate_default = True )
width : Optional [ int ] = Field ( None , ge = 1 , le = 100_000 , validate_default = True )
upscale : bool = False
fix_faces : bool = False
fix_faces_fidelity : Optional [ float ] = Field ( 0.2 , ge = 0 , le = 1 )
fix_faces_fidelity : Optional [ float ] = Field ( 0.2 , ge = 0 , le = 1 , validate_default = True )
conditioning : Optional [ str ] = None
tile_mode : str = " "
allow_compose_phase : bool = True
is_intermediate : bool = False
collect_progress_latents : bool = False
caption_text : str = Field ( " " , description = " text to be overlaid on the image " )
caption_text : str = Field (
" " , description = " text to be overlaid on the image " , validate_default = True
)
class MaskMode :
REPLACE = " replace "
@ -199,108 +298,150 @@ class ImaginePrompt(BaseModel):
# allows `prompt` to be positional
super ( ) . __init__ ( prompt = prompt , * * kwargs )
@validator ( " prompt " , " negative_prompt " , pre = True , always = True )
@field_validator ( " prompt " , " negative_prompt " , mode = " before " )
@classmethod
def make_into_weighted_prompts ( cls , v ) :
# if isinstance(v, list):
# v = [WeightedPrompt.parse_obj(p) if isinstance(p, dict) else p for p in v]
if isinstance ( v , str ) :
v = [ WeightedPrompt ( text = v ) ]
elif isinstance ( v , WeightedPrompt ) :
v = [ v ]
return v
@validator ( " prompt " , " negative_prompt " , always = True )
@field_validator ( " prompt " , " negative_prompt " , mode = " after " )
@classmethod
def must_have_some_weight ( cls , v ) :
if v :
total_weight = sum ( p . weight for p in v )
if total_weight == 0 :
raise ValueError ( " Total weight of prompts cannot be 0 " )
return v
@field_validator ( " prompt " , " negative_prompt " , mode = " after " )
def sort_prompts ( cls , v ) :
if isinstance ( v , list ) :
v . sort ( key = lambda p : p . weight , reverse = True )
return v
@validator ( " negative_prompt " , always = True )
def validate_negative_prompt ( cls , v , values ) :
if not v :
model_config = config . MODEL_CONFIG_SHORTCUTS . get ( v , None )
@ model_ validator( mode = " af ter" )
def validate_negative_prompt ( self ) :
if self . negative_prompt is None :
model_config = config . MODEL_CONFIG_SHORTCUTS . get ( self . model , None )
if model_config :
v = [ WeightedPrompt ( text = model_config . default_negative_prompt ) ]
self . negative_prompt = [
WeightedPrompt ( text = model_config . default_negative_prompt )
]
else :
v = [ WeightedPrompt ( text = config . DEFAULT_NEGATIVE_PROMPT ) ]
self . negative_prompt = [
WeightedPrompt ( text = config . DEFAULT_NEGATIVE_PROMPT )
]
return self
return v
@validator ( " prompt_strength " , always = True )
@field_validator ( " prompt_strength " )
def validate_prompt_strength ( cls , v ) :
return 7.5 if v is None else v
@ validator( " tile_mode " , always= True , pre = True )
@ field_ validator( " tile_mode " , mode= " before " )
def validate_tile_mode ( cls , v ) :
valid_tile_modes = ( " " , " x " , " y " , " xy " )
if v is True :
return " xy "
if v is False :
if v is False or v is None :
return " "
if not isinstance ( v , str ) :
raise ValueError (
f " Invalid tile_mode: ' { v } ' . Valid modes are: { valid_tile_modes } "
)
v = v . lower ( )
assert v in ( " " , " x " , " y " , " xy " )
if v not in valid_tile_modes :
raise ValueError (
f " Invalid tile_mode: ' { v } ' . Valid modes are: { valid_tile_modes } "
)
return v
@validator ( " init_image " , " mask_image " , always = True )
def handle_images ( cls , v ) :
if isinstance ( v , str ) :
return LazyLoadingImage ( filepath = v )
@field_validator ( " outpaint " , mode = " after " )
def validate_outpaint ( cls , v ) :
from imaginairy . outpaint import outpaint_arg_str_parse
outpaint_arg_str_parse ( v )
return v
@validator ( " init_image " , always = True )
def set_init_from_control_inputs ( cls , v , values ) :
if v is None and values . get ( " control_inputs " ) :
for control_input in values [ " control_inputs " ] :
@field_validator ( " conditioning " , mode = " after " )
def validate_conditioning ( cls , v ) :
from torch import Tensor
if v is None :
return v
if not isinstance ( v , Tensor ) :
raise ValueError ( " conditioning must be a torch.Tensor " )
return v
# @field_validator("init_image", "mask_image", mode="after")
# def handle_images(cls, v):
# if isinstance(v, str):
# return LazyLoadingImage(filepath=v)
#
# return v
@model_validator ( mode = " after " )
def set_init_from_control_inputs ( self ) :
if self . init_image is None :
for control_input in self . control_inputs :
if control_input . image :
return control_input . image
self . init_image = control_input . image
break
return self
@field_validator ( " control_inputs " , mode = " before " )
def validate_control_inputs ( cls , v ) :
if v is None :
v = [ ]
return v
@validator ( " control_inputs " , always = True )
def set_image_from_init_image ( cls , v , values ) :
@ field_ validator( " control_inputs " , mode= " after " )
def set_image_from_init_image ( cls , v , info: FieldValidationInfo ) :
v = v or [ ]
for control_input in v :
print ( control_input )
if control_input . image is None and control_input . image_raw is None :
control_input . image = values [ " init_image " ]
control_input . image = info. data [ " init_image " ]
return v
@validator ( " mask_image " , always = True )
def validate_mask_image ( cls , v , values ) :
if v is not None and values [ " mask_prompt " ] is not None :
@ field_ validator( " mask_image " )
def validate_mask_image ( cls , v , info: FieldValidationInfo ) :
if v is not None and info. data . get ( " mask_prompt " ) is not None :
raise ValueError ( " You can only set one of `mask_image` and `mask_prompt` " )
return v
@validator ( " mask_prompt " , always = True )
def validate_mask_prompt ( cls , v , values ) :
if values [ " init_image " ] is None and v :
raise ValueError (
" You must set `init_image` if you want to use `mask_prompt` "
)
@field_validator ( " mask_prompt " , " mask_image " , mode = " before " )
def validate_mask_prompt ( cls , v , info : FieldValidationInfo ) :
if info . data . get ( " init_image " ) is None and v :
raise ValueError ( " You must set `init_image` if you want to use a mask " )
return v
@ validator( " model " , always= True )
@ field_ validator( " model " , mode= " before " )
def set_default_diffusion_model ( cls , v ) :
if v is None :
return config . DEFAULT_MODEL
return v
@ validator( " seed " , always = True )
@ field_ validator( " seed " )
def validate_seed ( cls , v ) :
return v
@ validator( " fix_faces_fidelity " , always= True )
@ field_ validator( " fix_faces_fidelity " , mode= " before " )
def validate_fix_faces_fidelity ( cls , v ) :
if v is None :
return 0.2
return v
@ validator( " sampler_type " , pre= True , always = True )
def validate_sampler_type ( cls , v , values ) :
@ field_ validator( " sampler_type " , mode= " after " )
def validate_sampler_type ( cls , v , info: FieldValidationInfo ) :
from imaginairy . samplers import SamplerName
if v is None :
@ -308,10 +449,10 @@ class ImaginePrompt(BaseModel):
v = v . lower ( )
if values[ " model " ] == " SD-2.0-v " and v == SamplerName . PLMS :
if info. data . get ( " model " ) == " SD-2.0-v " and v == SamplerName . PLMS :
raise ValueError ( " PLMS sampler is not supported for SD-2.0-v model. " )
if values[ " model " ] == " edit " and v in (
if info. data . get ( " model " ) == " edit " and v in (
SamplerName . PLMS ,
SamplerName . DDIM ,
) :
@ -320,43 +461,39 @@ class ImaginePrompt(BaseModel):
)
return v
@ validator( " steps " , always = True )
def validate_steps ( cls , v , values ) :
@ field_ validator( " steps " )
def validate_steps ( cls , v , info: FieldValidationInfo ) :
from imaginairy . samplers import SAMPLER_LOOKUP
if v is None :
SamplerCls = SAMPLER_LOOKUP [ values [ " sampler_type " ] ]
SamplerCls = SAMPLER_LOOKUP [ info. data [ " sampler_type " ] ]
v = SamplerCls . default_steps
return int ( v )
@validator ( " init_image_strength " , always = True )
def validate_init_image_strength ( cls , v , values ) :
if v is None :
if values . get ( " control_inputs " ) :
v = 0.0
elif (
values . get ( " outpaint " )
or values . get ( " mask_image " )
or values . get ( " mask_prompt " )
) :
v = 0.0
@model_validator ( mode = " after " )
def validate_init_image_strength ( self ) :
if self . init_image_strength is None :
if self . control_inputs :
self . init_image_strength = 0.0
elif self . outpaint or self . mask_image or self . mask_prompt :
self . init_image_strength = 0.0
else :
v = 0.2
self . init_image_strength = 0.2
return v
return self
@ validator( " height " , " width " , always = True )
def validate_image_size ( cls , v , values ) :
@ field_ validator( " height " , " width " )
def validate_image_size ( cls , v , info: FieldValidationInfo ) :
from imaginairy . model_manager import get_model_default_image_size
if v is None :
v = get_model_default_image_size ( values [ " model " ] )
v = get_model_default_image_size ( info. data [ " model " ] )
return v
@ validator( " caption_text " , pre= True , always = True )
def validate_caption_text ( cls , v , values ):
@ field_ validator( " caption_text " , mode= " before " )
def validate_caption_text ( cls , v ):
if v is None :
v = " "
@ -391,7 +528,7 @@ class ImaginePrompt(BaseModel):
def logging_dict ( self ) :
""" Return a dict of the object but with binary data replaced with reprs. """
data = self . dict ( )
data = self . model_dump ( )
data [ " init_image " ] = repr ( self . init_image )
data [ " mask_image " ] = repr ( self . mask_image )
if self . control_inputs :
@ -399,17 +536,12 @@ class ImaginePrompt(BaseModel):
return data
def full_copy ( self , deep = True , update = None ) :
new_prompt = self . copy(
new_prompt = self . model_ copy(
deep = deep ,
update = update ,
)
new_prompt = new_prompt . validate (
dict (
new_prompt . _iter ( # noqa
to_dict = False , by_alias = False , exclude_unset = True
)
)
)
# new_prompt = self.model_validate(new_prompt) doesn't work for some reason https://github.com/pydantic/pydantic/issues/7387
new_prompt = new_prompt . model_validate ( dict ( new_prompt ) )
return new_prompt
def make_concrete_copy ( self ) :