ci: add type checker

fix some typehint issues
Bryce 2023-12-10 18:29:47 -08:00 committed by Bryce Drennan
parent e898e3a799
commit eae4f20ae2
14 changed files with 112 additions and 75 deletions

View File

@@ -40,6 +40,9 @@ lint: ## Run the code linter.
@ruff check --config tests/ruff.toml .
@echo -e "No linting errors - well done! ✨ 🍰 ✨"
type-check: ## Run the type checker.
@mypy --config-file tox.ini .
deploy: ## Deploy the package to pypi.org
pip install twine wheel
-git tag $$(python setup.py -V)
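The new type-check target just runs mypy over the repo. A minimal sketch, not from this commit, of the sort of annotation bug it is meant to catch (the function is hypothetical):

def scale_width(width: int, factor: float) -> int:
    # mypy: Incompatible return value type (got "float", expected "int")
    return width * factor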

View File

@@ -279,7 +279,7 @@ def _generate_single_image_compvis(
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
if inpaint_method == "auto":
if prompt.model in {"SD-1.5", "SD-2.0"}:
if prompt.model_weights in {"SD-1.5", "SD-2.0"}:
inpaint_method = "finetune"
else:
inpaint_method = "controlnet"
@@ -287,8 +287,8 @@ def _generate_single_image_compvis(
if for_inpainting and inpaint_method == "controlnet":
control_modes.append("inpaint")
model = get_diffusion_model(
weights_location=prompt.model,
config_path=prompt.model_config_path,
weights_location=prompt.model_weights,
config_path=prompt.model_architecture,
control_weights_locations=control_modes,
half_mode=half_mode,
for_inpainting=for_inpainting and inpaint_method == "finetune",
@@ -548,14 +548,16 @@ def _generate_single_image_compvis(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model_architecture),
cutoff=get_model_default_image_size(
prompt.model_weights.model_architecture
),
)
else:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model),
cutoff=get_model_default_image_size(prompt.model_architecture),
)
if comp_image is not None:
result_images["composition"] = comp_image
@@ -637,15 +639,16 @@ def _generate_single_image_compvis(
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
add_caption_to_image(gen_img, caption_text)
result_images["upscaled"] = upscaled_img
result_images["modified_original"] = rebuilt_orig_img
result_images["mask_binary"] = mask_image_orig
result_images["mask_grayscale"] = mask_grayscale
result = ImagineResult(
img=gen_img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
result_images=result_images,
timings=lc.get_timings(),
progress_latents=progress_latents.copy(),

View File

@@ -305,7 +305,7 @@ def _generate_single_image(
sd.scheduler.to(device=sd.device, dtype=sd.dtype)
sd.set_num_inference_steps(prompt.steps)
if hasattr(sd, "mask_latents"):
if hasattr(sd, "mask_latents") and mask_image is not None:
sd.set_inpainting_conditions(
target_image=init_image,
mask=ImageOps.invert(mask_image),
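The added None check matters for the type checker as much as for runtime: hasattr alone does not convince mypy that mask_image is non-None, while the explicit comparison narrows the optional type. A minimal illustration with a hypothetical function:

from PIL import Image, ImageOps

def invert_if_present(mask_image: Image.Image | None):
    if mask_image is not None:  # narrows "Image.Image | None" to "Image.Image"
        return ImageOps.invert(mask_image)
    return None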

View File

@@ -175,9 +175,9 @@ MODEL_WEIGHT_CONFIGS = [
]
MODEL_WEIGHT_CONFIG_LOOKUP = {}
for m in MODEL_WEIGHT_CONFIGS:
for a in m.aliases:
MODEL_WEIGHT_CONFIG_LOOKUP[a] = m
for mw in MODEL_WEIGHT_CONFIGS:
for a in mw.aliases:
MODEL_WEIGHT_CONFIG_LOOKUP[a] = mw
IMAGE_WEIGHTS_SHORT_NAMES = [
@@ -272,9 +272,9 @@ CONTROL_CONFIGS = [
]
CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {}
for m in CONTROL_CONFIGS:
for a in m.aliases:
CONTROL_CONFIG_SHORTCUTS[a] = m
for cc in CONTROL_CONFIGS:
for ca in cc.aliases:
CONTROL_CONFIG_SHORTCUTS[ca] = cc
@dataclass
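mypy infers a single type per name in module scope, so reusing m and a across loops over differently typed config lists is flagged as an incompatible redefinition; giving each loop its own names sidesteps that. The pattern in miniature, with hypothetical lists:

weight_configs: list[int] = [1, 2]
control_configs: list[str] = ["canny"]

for m in weight_configs:
    print(m)
for m in control_configs:  # mypy: expression has type "str", variable has type "int"
    print(m)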

View File

@@ -48,11 +48,11 @@ def pillow_fit_image_within(
def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"):
if convert:
img = img.convert(convert)
img = np.array(img).astype(np.float32) / 255.0
img_np = np.array(img).astype(np.float32) / 255.0
# b, h, w, c => b, c, h, w
img = img[None].transpose(0, 3, 1, 2)
img = torch.from_numpy(img)
return 2.0 * img - 1.0
img_np = img_np[None].transpose(0, 3, 1, 2)
img_t = torch.from_numpy(img_np)
return 2.0 * img_t - 1.0
def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor):
@@ -77,17 +77,17 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image):
return open_cv_image
def torch_image_to_openvcv_img(img: torch.Tensor):
def torch_image_to_openvcv_img(img: torch.Tensor) -> np.ndarray:
img = (img + 1) / 2
img = img.detach().cpu().numpy()
img_np = img.detach().cpu().numpy()
# assert there is only one image
assert img.shape[0] == 1
img = img[0]
img = img.transpose(1, 2, 0)
img = (img * 255).astype(np.uint8)
assert img_np.shape[0] == 1
img_np = img_np[0]
img_np = img_np.transpose(1, 2, 0)
img_np = (img_np * 255).astype(np.uint8)
# RGB to BGR
img = img[:, :, ::-1]
return img
img_np = img_np[:, :, ::-1]
return img_np
def torch_img_to_pillow_img(img_t: torch.Tensor):
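The renames above (img to img_np to img_t) give each intermediate its own name, so mypy can follow the PIL-to-numpy-to-torch pipeline instead of watching one variable change type three times. A rough usage sketch, assuming the helper is importable from imaginairy.img_utils:

from PIL import Image
from imaginairy.img_utils import pillow_img_to_torch_image

img = Image.new("RGB", (64, 64), "red")
img_t = pillow_img_to_torch_image(img)
print(tuple(img_t.shape))                      # (1, 3, 64, 64), i.e. b, c, h, w
print(img_t.min().item(), img_t.max().item())  # values scaled into [-1.0, 1.0]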

View File

@@ -1,3 +1,5 @@
# type: ignore
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

View File

@@ -9,7 +9,7 @@ import random
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, List
from typing import TYPE_CHECKING, Any, List, cast
from pydantic import (
BaseModel,
@@ -43,9 +43,9 @@ class LazyLoadingImage:
def __init__(
self,
*,
filepath=None,
url=None,
img: "Image.Image" = None,
filepath: str | None = None,
url: str | None = None,
img: "Image.Image | None" = None,
b64: str | None = None,
):
if not filepath and not url and not img and not b64:
@@ -255,8 +255,10 @@ PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
class ImaginePrompt(BaseModel, protected_namespaces=()):
model_config = ConfigDict(extra="forbid", validate_assignment=True)
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
negative_prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) # type: ignore
negative_prompt: List[WeightedPrompt] = Field(
default_factory=list, validate_default=True
)
prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
init_image: LazyLoadingImage | None = Field(
None, description="base64 encoded image", validate_default=True
@@ -282,8 +284,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
)
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
seed: int | None = Field(default=None, validate_default=True)
steps: int | None = Field(default=None, validate_default=True)
size: tuple[int, int] | None = Field(default=None, validate_default=True)
steps: int = Field(validate_default=True)
size: tuple[int, int] = Field(validate_default=True)
upscale: bool = False
fix_faces: bool = False
fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True)
@@ -373,9 +375,9 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return [value]
case list():
if all(isinstance(item, str) for item in value):
return [WeightedPrompt(text=p) for p in value]
return [WeightedPrompt(text=str(p)) for p in value]
elif all(isinstance(item, WeightedPrompt) for item in value):
return value
return cast(List[WeightedPrompt], value)
raise ValueError("Invalid prompt input")
@field_validator("prompt", "negative_prompt", mode="after")
@@ -532,7 +534,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
raise ValueError(msg)
return v
@field_validator("steps")
@field_validator("steps", mode="before")
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
steps_lookup = {"ddim": 50, "dpmpp": 20}
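Making steps a plain int while still allowing callers to omit it works because the validator now runs in mode="before": it can substitute a solver-dependent default before the int constraint is enforced. A standalone sketch of the pattern under pydantic v2 (names are hypothetical):

from pydantic import BaseModel, Field, ValidationInfo, field_validator

class SolverSettings(BaseModel):
    solver_type: str = "ddim"
    steps: int = Field(default=None, validate_default=True)  # type: ignore

    @field_validator("steps", mode="before")
    @classmethod
    def default_steps(cls, v, info: ValidationInfo):
        if v is None:
            # None is replaced before the "int" annotation is checked
            return {"ddim": 50, "dpmpp": 20}[info.data["solver_type"]]
        return v

print(SolverSettings().steps)                     # 50
print(SolverSettings(solver_type="dpmpp").steps)  # 20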

View File

@@ -4,7 +4,7 @@ import platform
import time
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
import torch
from torch import Tensor, autocast
@@ -62,7 +62,7 @@ def get_obj_from_str(import_path: str, reload=False) -> Any:
return getattr(module, obj_name)
def instantiate_from_config(config: Union[dict, str]) -> Any:
def instantiate_from_config(config: dict) -> Any:
"""Instantiate an object from a config dict."""
if "target" not in config:
if config == "__is_first_stage__":
@@ -70,6 +70,7 @@ def instantiate_from_config(config: Union[dict, str]) -> Any:
if config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
assert isinstance(config, dict)
params = config.get("params", {})
_cls = get_obj_from_str(config["target"])
start = time.perf_counter()
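With the str branch dropped from the signature, instantiate_from_config is a plain dict-to-object factory: target is an import path, params are constructor kwargs. A hypothetical call, assuming the helper is importable from imaginairy.utils:

from imaginairy.utils import instantiate_from_config

cfg = {"target": "datetime.timedelta", "params": {"seconds": 90}}
delay = instantiate_from_config(cfg)
print(delay)  # 0:01:30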

View File

@@ -49,31 +49,28 @@ _NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
match resolution:
case (int(), int()):
size = resolution
return resolution # type: ignore
case int():
size = resolution, resolution
return resolution, resolution
case str():
resolution = resolution.strip().upper()
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
size = _NAMED_RESOLUTIONS.get(resolution.upper())
if size is None:
# is it WIDTH,HEIGHT format?
try:
width, height = resolution.split(",")
size = int(width), int(height)
except ValueError:
pass
if size is None:
# is it just a single number?
with contextlib.suppress(ValueError):
size = (int(resolution), int(resolution))
if size is None:
msg = f"Invalid resolution: '{resolution}'"
raise ValueError(msg)
if resolution.upper() in _NAMED_RESOLUTIONS:
return _NAMED_RESOLUTIONS[resolution.upper()]
# is it WIDTH,HEIGHT format?
try:
width, height = resolution.split(",")
return int(width), int(height)
except ValueError:
pass
# is it just a single number?
with contextlib.suppress(ValueError):
return int(resolution), int(resolution)
msg = f"Invalid resolution: '{resolution}'"
raise ValueError(msg)
case _:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg)
if size[0] <= 0 or size[1] <= 0:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg)
return size
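Each branch now returns as soon as it can parse the input instead of funneling everything through a shared size variable, which is what lets mypy pin down the tuple[int, int] return type. A few calls it should accept, judging from the code above (treat the "HD" key as an assumption about _NAMED_RESOLUTIONS):

normalize_image_size((512, 768))  # (512, 768)
normalize_image_size(512)         # (512, 512)
normalize_image_size("512x768")   # (512, 768); "X" and "*" are normalized to ","
normalize_image_size("768")       # (768, 768)
normalize_image_size("HD")        # whatever _NAMED_RESOLUTIONS maps it to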

View File

@@ -330,20 +330,20 @@ def get_batch(keys, value_dict, N, T, device):
def load_model(
config: str, device: str, num_frames: int, num_steps: int, weights_url: str
):
config = OmegaConf.load(config)
oconfig = OmegaConf.load(config)
ckpt_path = get_cached_url_path(weights_url)
config["model"]["params"]["ckpt_path"] = ckpt_path
oconfig["model"]["params"]["ckpt_path"] = ckpt_path
if device == "cuda":
config.model.params.conditioner_config.params.emb_models[
oconfig.model.params.conditioner_config.params.emb_models[
0
].params.open_clip_embedding_config.params.init_device = device
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
oconfig.model.params.sampler_config.params.num_steps = num_steps
oconfig.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
)
model = instantiate_from_config(config.model).to(device).half().eval()
model = instantiate_from_config(oconfig.model).to(device).half().eval()
# safety_filter = DeepFloydDataFiltering(verbose=False, device=device)
def safety_filter(x):
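The oconfig rename exists because the parameter is annotated str, and re-binding it to the loaded OmegaConf object would change its type mid-function. The same pattern in miniature (hypothetical functions):

from omegaconf import OmegaConf

def load_bad(path: str):
    path = OmegaConf.load(path)  # mypy: Incompatible types in assignment
    return path

def load_good(path: str):
    oconf = OmegaConf.load(path)  # new name, new type; mypy is satisfied
    return oconf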

View File

@@ -89,7 +89,7 @@ class WeightMap:
@lru_cache(maxsize=None)
def load_state_dict_conversion_maps():
def load_state_dict_conversion_maps() -> dict[str, dict]:
import json
conversion_maps = {}
@@ -102,7 +102,11 @@ def load_state_dict_conversion_maps():
def cast_weights(
source_weights, source_model_name, source_component_name, source_format, dest_format
source_weights,
source_model_name: str,
source_component_name: str,
source_format: str,
dest_format: str,
):
weight_map = WeightMap(
model_name=source_model_name,

View File

@@ -1,10 +1,15 @@
black
coverage
mypy
ruff
pytest
pytest-randomly
pytest-sugar
responses
types-pillow
types-psutil
types-requests
types-tqdm
wheel
-c tests/constraints.txt

View File

@@ -71,7 +71,7 @@ frozenlist==1.4.0
# via
# aiohttp
# aiosignal
fsspec[http]==2023.12.0
fsspec[http]==2023.12.1
# via
# huggingface-hub
# pytorch-lightning
@@ -126,8 +126,12 @@ multidict==6.0.4
# via
# aiohttp
# yarl
mypy==1.7.1
# via -r requirements-dev.in
mypy-extensions==1.0.0
# via black
# via
# black
# mypy
networkx==3.2.1
# via torch
numba==0.58.1
@@ -172,7 +176,7 @@ packaging==23.2
# pytorch-lightning
# torchmetrics
# transformers
pathspec==0.11.2
pathspec==0.12.1
# via black
pillow==10.1.0
# via
@@ -276,6 +280,7 @@ tokenizers==0.15.0
tomli==2.0.1
# via
# black
# mypy
# pytest
torch==2.1.1
# via
@@ -314,13 +319,22 @@ transformers==4.35.2
# via imaginAIry (setup.py)
typeguard==2.13.3
# via jaxtyping
typing-extensions==4.8.0
types-pillow==10.1.0.2
# via -r requirements-dev.in
types-psutil==5.9.5.17
# via -r requirements-dev.in
types-requests==2.31.0.10
# via -r requirements-dev.in
types-tqdm==4.66.0.5
# via -r requirements-dev.in
typing-extensions==4.9.0
# via
# black
# fastapi
# huggingface-hub
# jaxtyping
# lightning-utilities
# mypy
# pydantic
# pydantic-core
# pytorch-lightning
@@ -330,13 +344,14 @@ urllib3==2.1.0
# via
# requests
# responses
# types-requests
uvicorn==0.24.0.post1
# via imaginAIry (setup.py)
wcwidth==0.2.12
# via ftfy
wheel==0.42.0
# via -r requirements-dev.in
yarl==1.9.3
yarl==1.9.4
# via aiohttp
zipp==3.17.0
# via importlib-metadata

View File

@@ -4,3 +4,8 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
[mypy]
plugins = pydantic.mypy
exclude = ^(downloads|dist|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)/.*
ignore_missing_imports = True
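The pydantic.mypy plugin teaches mypy about things pydantic generates at runtime, such as the synthesized __init__ of each model. A hedged sketch of what it catches (the model is hypothetical):

from pydantic import BaseModel

class Size(BaseModel):
    width: int
    height: int

Size(width=512, height="tall")  # plugin: Argument "height" has incompatible type "str"

ignore_missing_imports silences errors for third-party packages that ship no type stubs, and the exclude line keeps vendored and generated code out of the run.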