mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
ci: add type checker
fix some typehint issues
This commit is contained in:
parent
e898e3a799
commit
eae4f20ae2
3
Makefile
3
Makefile
@ -40,6 +40,9 @@ lint: ## Run the code linter.
|
||||
@ruff check --config tests/ruff.toml .
|
||||
@echo -e "No linting errors - well done! ✨ 🍰 ✨"
|
||||
|
||||
type-check: ## Run the type checker.
|
||||
@mypy --config-file tox.ini .
|
||||
|
||||
deploy: ## Deploy the package to pypi.org
|
||||
pip install twine wheel
|
||||
-git tag $$(python setup.py -V)
|
||||
|
@ -279,7 +279,7 @@ def _generate_single_image_compvis(
|
||||
if control_inputs:
|
||||
control_modes = [c.mode for c in prompt.control_inputs]
|
||||
if inpaint_method == "auto":
|
||||
if prompt.model in {"SD-1.5", "SD-2.0"}:
|
||||
if prompt.model_weights in {"SD-1.5", "SD-2.0"}:
|
||||
inpaint_method = "finetune"
|
||||
else:
|
||||
inpaint_method = "controlnet"
|
||||
@ -287,8 +287,8 @@ def _generate_single_image_compvis(
|
||||
if for_inpainting and inpaint_method == "controlnet":
|
||||
control_modes.append("inpaint")
|
||||
model = get_diffusion_model(
|
||||
weights_location=prompt.model,
|
||||
config_path=prompt.model_config_path,
|
||||
weights_location=prompt.model_weights,
|
||||
config_path=prompt.model_architecture,
|
||||
control_weights_locations=control_modes,
|
||||
half_mode=half_mode,
|
||||
for_inpainting=for_inpainting and inpaint_method == "finetune",
|
||||
@ -548,14 +548,16 @@ def _generate_single_image_compvis(
|
||||
prompt=prompt,
|
||||
target_height=init_image.height,
|
||||
target_width=init_image.width,
|
||||
cutoff=get_model_default_image_size(prompt.model_architecture),
|
||||
cutoff=get_model_default_image_size(
|
||||
prompt.model_weights.model_architecture
|
||||
),
|
||||
)
|
||||
else:
|
||||
comp_image = _generate_composition_image(
|
||||
prompt=prompt,
|
||||
target_height=prompt.height,
|
||||
target_width=prompt.width,
|
||||
cutoff=get_model_default_image_size(prompt.model),
|
||||
cutoff=get_model_default_image_size(prompt.model_architecture),
|
||||
)
|
||||
if comp_image is not None:
|
||||
result_images["composition"] = comp_image
|
||||
@ -637,15 +639,16 @@ def _generate_single_image_compvis(
|
||||
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
|
||||
add_caption_to_image(gen_img, caption_text)
|
||||
|
||||
result_images["upscaled"] = upscaled_img
|
||||
result_images["modified_original"] = rebuilt_orig_img
|
||||
result_images["mask_binary"] = mask_image_orig
|
||||
result_images["mask_grayscale"] = mask_grayscale
|
||||
|
||||
result = ImagineResult(
|
||||
img=gen_img,
|
||||
prompt=prompt,
|
||||
upscaled_img=upscaled_img,
|
||||
is_nsfw=safety_score.is_nsfw,
|
||||
safety_score=safety_score,
|
||||
modified_original=rebuilt_orig_img,
|
||||
mask_binary=mask_image_orig,
|
||||
mask_grayscale=mask_grayscale,
|
||||
result_images=result_images,
|
||||
timings=lc.get_timings(),
|
||||
progress_latents=progress_latents.copy(),
|
||||
|
@ -305,7 +305,7 @@ def _generate_single_image(
|
||||
sd.scheduler.to(device=sd.device, dtype=sd.dtype)
|
||||
sd.set_num_inference_steps(prompt.steps)
|
||||
|
||||
if hasattr(sd, "mask_latents"):
|
||||
if hasattr(sd, "mask_latents") and mask_image is not None:
|
||||
sd.set_inpainting_conditions(
|
||||
target_image=init_image,
|
||||
mask=ImageOps.invert(mask_image),
|
||||
|
@ -175,9 +175,9 @@ MODEL_WEIGHT_CONFIGS = [
|
||||
]
|
||||
|
||||
MODEL_WEIGHT_CONFIG_LOOKUP = {}
|
||||
for m in MODEL_WEIGHT_CONFIGS:
|
||||
for a in m.aliases:
|
||||
MODEL_WEIGHT_CONFIG_LOOKUP[a] = m
|
||||
for mw in MODEL_WEIGHT_CONFIGS:
|
||||
for a in mw.aliases:
|
||||
MODEL_WEIGHT_CONFIG_LOOKUP[a] = mw
|
||||
|
||||
|
||||
IMAGE_WEIGHTS_SHORT_NAMES = [
|
||||
@ -272,9 +272,9 @@ CONTROL_CONFIGS = [
|
||||
]
|
||||
|
||||
CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {}
|
||||
for m in CONTROL_CONFIGS:
|
||||
for a in m.aliases:
|
||||
CONTROL_CONFIG_SHORTCUTS[a] = m
|
||||
for cc in CONTROL_CONFIGS:
|
||||
for ca in cc.aliases:
|
||||
CONTROL_CONFIG_SHORTCUTS[ca] = cc
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -48,11 +48,11 @@ def pillow_fit_image_within(
|
||||
def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"):
|
||||
if convert:
|
||||
img = img.convert(convert)
|
||||
img = np.array(img).astype(np.float32) / 255.0
|
||||
img_np = np.array(img).astype(np.float32) / 255.0
|
||||
# b, h, w, c => b, c, h, w
|
||||
img = img[None].transpose(0, 3, 1, 2)
|
||||
img = torch.from_numpy(img)
|
||||
return 2.0 * img - 1.0
|
||||
img_np = img_np[None].transpose(0, 3, 1, 2)
|
||||
img_t = torch.from_numpy(img_np)
|
||||
return 2.0 * img_t - 1.0
|
||||
|
||||
|
||||
def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor):
|
||||
@ -77,17 +77,17 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image):
|
||||
return open_cv_image
|
||||
|
||||
|
||||
def torch_image_to_openvcv_img(img: torch.Tensor):
|
||||
def torch_image_to_openvcv_img(img: torch.Tensor) -> np.ndarray:
|
||||
img = (img + 1) / 2
|
||||
img = img.detach().cpu().numpy()
|
||||
img_np = img.detach().cpu().numpy()
|
||||
# assert there is only one image
|
||||
assert img.shape[0] == 1
|
||||
img = img[0]
|
||||
img = img.transpose(1, 2, 0)
|
||||
img = (img * 255).astype(np.uint8)
|
||||
assert img_np.shape[0] == 1
|
||||
img_np = img_np[0]
|
||||
img_np = img_np.transpose(1, 2, 0)
|
||||
img_np = (img_np * 255).astype(np.uint8)
|
||||
# RGB to BGR
|
||||
img = img[:, :, ::-1]
|
||||
return img
|
||||
img_np = img_np[:, :, ::-1]
|
||||
return img_np
|
||||
|
||||
|
||||
def torch_img_to_pillow_img(img_t: torch.Tensor):
|
||||
|
@ -1,3 +1,5 @@
|
||||
# type: ignore
|
||||
|
||||
"""
|
||||
wild mixture of
|
||||
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
|
@ -9,7 +9,7 @@ import random
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
from typing import TYPE_CHECKING, Any, List, cast
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@ -43,9 +43,9 @@ class LazyLoadingImage:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
filepath=None,
|
||||
url=None,
|
||||
img: "Image.Image" = None,
|
||||
filepath: str | None = None,
|
||||
url: str | None = None,
|
||||
img: "Image.Image | None" = None,
|
||||
b64: str | None = None,
|
||||
):
|
||||
if not filepath and not url and not img and not b64:
|
||||
@ -255,8 +255,10 @@ PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
|
||||
class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
|
||||
negative_prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
|
||||
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) # type: ignore
|
||||
negative_prompt: List[WeightedPrompt] = Field(
|
||||
default_factory=list, validate_default=True
|
||||
)
|
||||
prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
|
||||
init_image: LazyLoadingImage | None = Field(
|
||||
None, description="base64 encoded image", validate_default=True
|
||||
@ -282,8 +284,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||
)
|
||||
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
|
||||
seed: int | None = Field(default=None, validate_default=True)
|
||||
steps: int | None = Field(default=None, validate_default=True)
|
||||
size: tuple[int, int] | None = Field(default=None, validate_default=True)
|
||||
steps: int = Field(validate_default=True)
|
||||
size: tuple[int, int] = Field(validate_default=True)
|
||||
upscale: bool = False
|
||||
fix_faces: bool = False
|
||||
fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True)
|
||||
@ -373,9 +375,9 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||
return [value]
|
||||
case list():
|
||||
if all(isinstance(item, str) for item in value):
|
||||
return [WeightedPrompt(text=p) for p in value]
|
||||
return [WeightedPrompt(text=str(p)) for p in value]
|
||||
elif all(isinstance(item, WeightedPrompt) for item in value):
|
||||
return value
|
||||
return cast(List[WeightedPrompt], value)
|
||||
raise ValueError("Invalid prompt input")
|
||||
|
||||
@field_validator("prompt", "negative_prompt", mode="after")
|
||||
@ -532,7 +534,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
@field_validator("steps")
|
||||
@field_validator("steps", mode="before")
|
||||
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
|
||||
steps_lookup = {"ddim": 50, "dpmpp": 20}
|
||||
|
||||
|
@ -4,7 +4,7 @@ import platform
|
||||
import time
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, autocast
|
||||
@ -62,7 +62,7 @@ def get_obj_from_str(import_path: str, reload=False) -> Any:
|
||||
return getattr(module, obj_name)
|
||||
|
||||
|
||||
def instantiate_from_config(config: Union[dict, str]) -> Any:
|
||||
def instantiate_from_config(config: dict) -> Any:
|
||||
"""Instantiate an object from a config dict."""
|
||||
if "target" not in config:
|
||||
if config == "__is_first_stage__":
|
||||
@ -70,6 +70,7 @@ def instantiate_from_config(config: Union[dict, str]) -> Any:
|
||||
if config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
assert isinstance(config, dict)
|
||||
params = config.get("params", {})
|
||||
_cls = get_obj_from_str(config["target"])
|
||||
start = time.perf_counter()
|
||||
|
@ -49,31 +49,28 @@ _NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
|
||||
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
|
||||
match resolution:
|
||||
case (int(), int()):
|
||||
size = resolution
|
||||
return resolution # type: ignore
|
||||
case int():
|
||||
size = resolution, resolution
|
||||
return resolution, resolution
|
||||
case str():
|
||||
resolution = resolution.strip().upper()
|
||||
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
|
||||
size = _NAMED_RESOLUTIONS.get(resolution.upper())
|
||||
if size is None:
|
||||
# is it WIDTH,HEIGHT format?
|
||||
try:
|
||||
width, height = resolution.split(",")
|
||||
size = int(width), int(height)
|
||||
except ValueError:
|
||||
pass
|
||||
if size is None:
|
||||
# is it just a single number?
|
||||
with contextlib.suppress(ValueError):
|
||||
size = (int(resolution), int(resolution))
|
||||
if size is None:
|
||||
msg = f"Invalid resolution: '{resolution}'"
|
||||
raise ValueError(msg)
|
||||
if resolution.upper() in _NAMED_RESOLUTIONS:
|
||||
return _NAMED_RESOLUTIONS[resolution.upper()]
|
||||
|
||||
# is it WIDTH,HEIGHT format?
|
||||
try:
|
||||
width, height = resolution.split(",")
|
||||
return int(width), int(height)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# is it just a single number?
|
||||
with contextlib.suppress(ValueError):
|
||||
return int(resolution), int(resolution)
|
||||
|
||||
msg = f"Invalid resolution: '{resolution}'"
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
msg = f"Invalid resolution: {resolution!r}"
|
||||
raise ValueError(msg)
|
||||
if size[0] <= 0 or size[1] <= 0:
|
||||
msg = f"Invalid resolution: {resolution!r}"
|
||||
raise ValueError(msg)
|
||||
return size
|
||||
|
@ -330,20 +330,20 @@ def get_batch(keys, value_dict, N, T, device):
|
||||
def load_model(
|
||||
config: str, device: str, num_frames: int, num_steps: int, weights_url: str
|
||||
):
|
||||
config = OmegaConf.load(config)
|
||||
oconfig = OmegaConf.load(config)
|
||||
ckpt_path = get_cached_url_path(weights_url)
|
||||
config["model"]["params"]["ckpt_path"] = ckpt_path
|
||||
oconfig["model"]["params"]["ckpt_path"] = ckpt_path
|
||||
if device == "cuda":
|
||||
config.model.params.conditioner_config.params.emb_models[
|
||||
oconfig.model.params.conditioner_config.params.emb_models[
|
||||
0
|
||||
].params.open_clip_embedding_config.params.init_device = device
|
||||
|
||||
config.model.params.sampler_config.params.num_steps = num_steps
|
||||
config.model.params.sampler_config.params.guider_config.params.num_frames = (
|
||||
oconfig.model.params.sampler_config.params.num_steps = num_steps
|
||||
oconfig.model.params.sampler_config.params.guider_config.params.num_frames = (
|
||||
num_frames
|
||||
)
|
||||
|
||||
model = instantiate_from_config(config.model).to(device).half().eval()
|
||||
model = instantiate_from_config(oconfig.model).to(device).half().eval()
|
||||
|
||||
# safety_filter = DeepFloydDataFiltering(verbose=False, device=device)
|
||||
def safety_filter(x):
|
||||
|
@ -89,7 +89,7 @@ class WeightMap:
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def load_state_dict_conversion_maps():
|
||||
def load_state_dict_conversion_maps() -> dict[str, dict]:
|
||||
import json
|
||||
|
||||
conversion_maps = {}
|
||||
@ -102,7 +102,11 @@ def load_state_dict_conversion_maps():
|
||||
|
||||
|
||||
def cast_weights(
|
||||
source_weights, source_model_name, source_component_name, source_format, dest_format
|
||||
source_weights,
|
||||
source_model_name: str,
|
||||
source_component_name: str,
|
||||
source_format: str,
|
||||
dest_format: str,
|
||||
):
|
||||
weight_map = WeightMap(
|
||||
model_name=source_model_name,
|
||||
|
@ -1,10 +1,15 @@
|
||||
black
|
||||
coverage
|
||||
mypy
|
||||
ruff
|
||||
pytest
|
||||
pytest-randomly
|
||||
pytest-sugar
|
||||
responses
|
||||
types-pillow
|
||||
types-psutil
|
||||
types-requests
|
||||
types-tqdm
|
||||
wheel
|
||||
|
||||
-c tests/constraints.txt
|
@ -71,7 +71,7 @@ frozenlist==1.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2023.12.0
|
||||
fsspec[http]==2023.12.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pytorch-lightning
|
||||
@ -126,8 +126,12 @@ multidict==6.0.4
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy==1.7.1
|
||||
# via -r requirements-dev.in
|
||||
mypy-extensions==1.0.0
|
||||
# via black
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
numba==0.58.1
|
||||
@ -172,7 +176,7 @@ packaging==23.2
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
# transformers
|
||||
pathspec==0.11.2
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
pillow==10.1.0
|
||||
# via
|
||||
@ -276,6 +280,7 @@ tokenizers==0.15.0
|
||||
tomli==2.0.1
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
# pytest
|
||||
torch==2.1.1
|
||||
# via
|
||||
@ -314,13 +319,22 @@ transformers==4.35.2
|
||||
# via imaginAIry (setup.py)
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
typing-extensions==4.8.0
|
||||
types-pillow==10.1.0.2
|
||||
# via -r requirements-dev.in
|
||||
types-psutil==5.9.5.17
|
||||
# via -r requirements-dev.in
|
||||
types-requests==2.31.0.10
|
||||
# via -r requirements-dev.in
|
||||
types-tqdm==4.66.0.5
|
||||
# via -r requirements-dev.in
|
||||
typing-extensions==4.9.0
|
||||
# via
|
||||
# black
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# jaxtyping
|
||||
# lightning-utilities
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pytorch-lightning
|
||||
@ -330,13 +344,14 @@ urllib3==2.1.0
|
||||
# via
|
||||
# requests
|
||||
# responses
|
||||
# types-requests
|
||||
uvicorn==0.24.0.post1
|
||||
# via imaginAIry (setup.py)
|
||||
wcwidth==0.2.12
|
||||
# via ftfy
|
||||
wheel==0.42.0
|
||||
# via -r requirements-dev.in
|
||||
yarl==1.9.3
|
||||
yarl==1.9.4
|
||||
# via aiohttp
|
||||
zipp==3.17.0
|
||||
# via importlib-metadata
|
||||
|
5
tox.ini
5
tox.ini
@ -4,3 +4,8 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
|
||||
[mypy]
|
||||
plugins = pydantic.mypy
|
||||
exclude = ^(downloads|dist|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)/.*
|
||||
ignore_missing_imports = True
|
||||
|
Loading…
Reference in New Issue
Block a user