|
|
|
@ -9,7 +9,7 @@ import random
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from typing import TYPE_CHECKING, Any, List, cast
|
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Literal, cast
|
|
|
|
|
|
|
|
|
|
from pydantic import (
|
|
|
|
|
BaseModel,
|
|
|
|
@ -250,6 +250,7 @@ class MaskMode(str, Enum):
|
|
|
|
|
|
|
|
|
|
MaskInput = MaskMode | str
|
|
|
|
|
PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
|
|
|
|
|
InpaintMethod = Literal["finetune", "control"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
@ -278,8 +279,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
mask_mode: MaskMode = MaskMode.REPLACE
|
|
|
|
|
mask_modify_original: bool = True
|
|
|
|
|
outpaint: str | None = ""
|
|
|
|
|
model_architecture: str | None = None
|
|
|
|
|
model_weights: str = Field(
|
|
|
|
|
model_weights: config.ModelWeightsConfig = Field(
|
|
|
|
|
default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
|
|
|
|
|
)
|
|
|
|
|
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
|
|
|
|
@ -297,6 +297,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
caption_text: str = Field(
|
|
|
|
|
"", description="text to be overlaid on the image", validate_default=True
|
|
|
|
|
)
|
|
|
|
|
inpaint_method: InpaintMethod = "finetune"
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
@ -312,8 +313,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
mask_mode: MaskInput = MaskMode.REPLACE,
|
|
|
|
|
mask_modify_original: bool = True,
|
|
|
|
|
outpaint: str | None = "",
|
|
|
|
|
model_architecture: str | None = None,
|
|
|
|
|
model_weights: str = config.DEFAULT_MODEL_WEIGHTS,
|
|
|
|
|
model_weights: str | config.ModelWeightsConfig = config.DEFAULT_MODEL_WEIGHTS,
|
|
|
|
|
solver_type: str = config.DEFAULT_SOLVER,
|
|
|
|
|
seed: int | None = None,
|
|
|
|
|
steps: int | None = None,
|
|
|
|
@ -327,6 +327,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
is_intermediate: bool = False,
|
|
|
|
|
collect_progress_latents: bool = False,
|
|
|
|
|
caption_text: str = "",
|
|
|
|
|
inpaint_method: InpaintMethod = "finetune",
|
|
|
|
|
):
|
|
|
|
|
super().__init__(
|
|
|
|
|
prompt=prompt,
|
|
|
|
@ -340,7 +341,6 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
mask_mode=mask_mode,
|
|
|
|
|
mask_modify_original=mask_modify_original,
|
|
|
|
|
outpaint=outpaint,
|
|
|
|
|
model_architecture=model_architecture,
|
|
|
|
|
model_weights=model_weights,
|
|
|
|
|
solver_type=solver_type,
|
|
|
|
|
seed=seed,
|
|
|
|
@ -355,6 +355,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
is_intermediate=is_intermediate,
|
|
|
|
|
collect_progress_latents=collect_progress_latents,
|
|
|
|
|
caption_text=caption_text,
|
|
|
|
|
inpaint_method=inpaint_method,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@field_validator("prompt", "negative_prompt", mode="before")
|
|
|
|
@ -398,12 +399,9 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
@model_validator(mode="after")
|
|
|
|
|
def validate_negative_prompt(self):
|
|
|
|
|
if self.negative_prompt == []:
|
|
|
|
|
model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(
|
|
|
|
|
self.model_weights, None
|
|
|
|
|
)
|
|
|
|
|
default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
|
|
|
|
|
if model_weight_config:
|
|
|
|
|
default_negative_prompt = model_weight_config.defaults.get(
|
|
|
|
|
if self.model_weights:
|
|
|
|
|
default_negative_prompt = self.model_weights.defaults.get(
|
|
|
|
|
"negative_prompt", default_negative_prompt
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -496,12 +494,30 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
@field_validator("model_weights", mode="before")
|
|
|
|
|
def set_default_diffusion_model(cls, v):
|
|
|
|
|
if v is None:
|
|
|
|
|
return config.DEFAULT_MODEL_WEIGHTS
|
|
|
|
|
@model_validator(mode="before")
|
|
|
|
|
def resolve_model_weights(cls, data: Any):
|
|
|
|
|
if not isinstance(data, dict):
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
model_weights = data.get("model_weights")
|
|
|
|
|
if model_weights is None:
|
|
|
|
|
model_weights = config.DEFAULT_MODEL_WEIGHTS
|
|
|
|
|
from imaginairy.model_manager import resolve_model_weights_config
|
|
|
|
|
|
|
|
|
|
should_use_inpainting = (
|
|
|
|
|
data.get("mask_image") or data.get("mask_prompt") or data.get("outpaint")
|
|
|
|
|
)
|
|
|
|
|
should_use_inpainting_weights = (
|
|
|
|
|
should_use_inpainting and data.get("inpaint_method") == "finetune"
|
|
|
|
|
)
|
|
|
|
|
model_weights_config = resolve_model_weights_config(
|
|
|
|
|
model_weights=model_weights,
|
|
|
|
|
default_model_architecture=None,
|
|
|
|
|
for_inpainting=should_use_inpainting_weights,
|
|
|
|
|
)
|
|
|
|
|
data["model_weights"] = model_weights_config
|
|
|
|
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
@field_validator("seed")
|
|
|
|
|
def validate_seed(cls, v):
|
|
|
|
@ -564,9 +580,15 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
|
|
|
|
|
|
|
|
|
if v is None:
|
|
|
|
|
v = get_model_default_image_size(info.data["model_architecture"])
|
|
|
|
|
v = get_model_default_image_size(info.data["model_weights"].architecture)
|
|
|
|
|
|
|
|
|
|
width, height = normalize_image_size(v)
|
|
|
|
|
|
|
|
|
|
return width, height
|
|
|
|
|
|
|
|
|
|
@field_validator("size", mode="after")
|
|
|
|
|
def validate_image_size_after(cls, v, info: core_schema.FieldValidationInfo):
|
|
|
|
|
width, height = v
|
|
|
|
|
min_size = 8
|
|
|
|
|
max_size = 100_000
|
|
|
|
|
if not min_size <= width <= max_size:
|
|
|
|
@ -576,8 +598,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
if not min_size <= height <= max_size:
|
|
|
|
|
msg = f"Height must be between {min_size} and {max_size}. Got: {height}"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
|
|
|
|
return width, height
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
@field_validator("caption_text", mode="before")
|
|
|
|
|
def validate_caption_text(cls, v):
|
|
|
|
@ -614,6 +635,18 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
def height(self) -> int:
|
|
|
|
|
return self.size[1]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def should_use_inpainting(self) -> bool:
|
|
|
|
|
return bool(self.outpaint or self.mask_image or self.mask_prompt)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def should_use_inpainting_weights(self) -> bool:
|
|
|
|
|
return self.should_use_inpainting and self.inpaint_method == "finetune"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def model_architecture(self) -> config.ModelArchitecture:
|
|
|
|
|
return self.model_weights.architecture
|
|
|
|
|
|
|
|
|
|
def prompt_description(self):
|
|
|
|
|
return (
|
|
|
|
|
f'"{self.prompt_text}" {self.width}x{self.height}px '
|
|
|
|
@ -622,8 +655,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
|
|
|
|
f"prompt-strength:{self.prompt_strength} "
|
|
|
|
|
f"steps:{self.steps} solver-type:{self.solver_type} "
|
|
|
|
|
f"init-image-strength:{self.init_image_strength} "
|
|
|
|
|
f"arch:{self.model_architecture} "
|
|
|
|
|
f"weights: {self.model_weights}"
|
|
|
|
|
f"arch:{self.model_architecture.aliases[0]} "
|
|
|
|
|
f"weights: {self.model_weights.aliases[0]}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def logging_dict(self):
|
|
|
|
@ -730,8 +763,8 @@ class ImagineResult:
|
|
|
|
|
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
|
|
|
|
|
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
|
|
|
|
|
# help future web scrapes not ingest AI generated art
|
|
|
|
|
sd_version = self.prompt.model_weights
|
|
|
|
|
if len(sd_version) > 20:
|
|
|
|
|
sd_version = self.prompt.model_weights.name
|
|
|
|
|
if len(sd_version) > 40:
|
|
|
|
|
sd_version = "custom weights"
|
|
|
|
|
exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}"
|
|
|
|
|
exif[ExifCodes.DateTime] = self.created_at.isoformat(sep=" ")[:19]
|
|
|
|
|