style: speed up linting and autoformatting; fix lints

pull/385/head
Bryce authored 8 months ago, committed by Bryce Drennan
parent 460add16b8
commit 558d3388e5

@ -9,35 +9,42 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
cache: pip
cache-dependency-path: requirements-dev.txt
- name: Install dependencies
run: |
python -m pip install --disable-pip-version-check wheel pip-tools
pip-sync requirements-dev.txt
python -m pip install --disable-pip-version-check --no-deps .
- name: Lint
run: |
echo "::add-matcher::.github/pylama_matcher.json"
pylama --options tox.ini
- uses: actions/checkout@v3
- uses: actions/setup-python@v4.5.0
with:
python-version: 3.9
- name: Cache dependencies
uses: actions/cache@v3.2.4
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-lint
- name: Install Ruff
if: steps.cache.outputs.cache-hit != 'true'
run: grep -E 'ruff==' requirements-dev.txt | xargs pip install
- name: Lint
run: |
echo "::add-matcher::.github/pylama_matcher.json"
ruff --config tests/ruff.toml .
autoformat:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --disable-pip-version-check black==23.1.0 isort==5.12.0
- name: Autoformatter
run: |
black --diff .
isort --atomic --profile black --check-only .
- uses: actions/checkout@v3
- uses: actions/setup-python@v4.5.0
with:
python-version: 3.9
- name: Cache dependencies
uses: actions/cache@v3.2.4
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-autoformat
- name: Install Black
if: steps.cache.outputs.cache-hit != 'true'
run: grep -E 'black==' requirements-dev.txt | xargs pip install
- name: Lint
run: |
black --diff --fast .
test:
runs-on: ubuntu-latest
strategy:

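The lint and autoformat jobs above cache the whole Python environment and install only the pinned tool, extracted from requirements-dev.txt with grep -E 'ruff==' requirements-dev.txt | xargs pip install. A rough Python equivalent of that extraction step, just for illustration (pinned_requirement is a hypothetical helper, not part of the repo):

import re
from typing import Optional

def pinned_requirement(path: str, package: str) -> Optional[str]:
    # mirrors: grep -E 'ruff==' requirements-dev.txt | xargs pip install
    pattern = re.compile(rf"^{re.escape(package)}==\S+")
    with open(path) as f:
        for line in f:
            match = pattern.match(line.strip())
            if match:
                return match.group(0)  # the pinned spec, e.g. "ruff==<version>"
    return None

print(pinned_requirement("requirements-dev.txt", "ruff"))
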
@ -28,18 +28,16 @@ init: require_pyenv ## Setup a dev environment for local development.
af: autoformat ## Alias for `autoformat`
autoformat: ## Run the autoformatter.
@pycln . --all --quiet --extend-exclude __init__\.py
@# ERA,T201
@-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,W --unfixable T,ERA --fix-only .
@-ruff check --config tests/ruff.toml . --fix-only
@black .
@isort --atomic --profile black --skip downloads/** .
test: ## Run the tests.
@pytest
@echo -e "The tests pass! ✨ 🍰 ✨"
lint: ## Run the code linter.
@pylama
@ruff check --config tests/ruff.toml .
@echo -e "No linting errors - well done! ✨ 🍰 ✨"
deploy: ## Deploy the package to pypi.org

@ -63,7 +63,7 @@ def generate_image_morph_video():
if os.path.exists(filename):
continue
result = list(imagine([prompt]))[0]
result = next(iter(imagine([prompt])))
generated_image = result.images["generated"]
draw = ImageDraw.Draw(generated_image)

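In the hunk above, list(imagine([prompt]))[0] built a full list just to read its first element; next(iter(...)) pulls a single item from the generator instead (ruff rule RUF015). A minimal sketch with a stand-in generator in place of imagine():

def results():
    yield "first"
    yield "second"   # never produced below

first = next(iter(results()))   # consumes only one item
# first = list(results())[0]    # old form: runs the generator to exhaustion first
assert first == "first"
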
@ -33,7 +33,7 @@ def make_bounce_animation(
middle_imgs = shrink_list(middle_imgs, max_frames)
frames = [first_img] + middle_imgs + [last_img] + list(reversed(middle_imgs))
frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))]
# convert from latents
converted_frames = []

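Unpacking inside the list literal builds the frame sequence in one pass instead of chaining + concatenations, each of which allocates an intermediate list (ruff rule RUF005). A small sketch of the equivalence:

first, middle, last = 0, [1, 2, 3], 4

frames = [first, *middle, last, *reversed(middle)]
# equivalent to: [first] + middle + [last] + list(reversed(middle))
assert frames == [0, 1, 2, 3, 4, 3, 2, 1]
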
@ -92,13 +92,13 @@ def imagine_image_files(
os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(subpath, f"{basefilename}.gif")
frames = result.progress_latents + [result.images["generated"]]
frames = [*result.progress_latents, result.images["generated"]]
if prompt.init_image:
resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height
)
frames = [resized_init_image] + frames
frames = [resized_init_image, *frames]
frames.reverse()
make_bounce_animation(
imgs=frames,
@ -170,7 +170,7 @@ def imagine(
logger.info(
f"🖼 Generating {i + 1}/{num_prompts}: {prompt.prompt_description()}"
)
for attempt in range(0, unsafe_retry_count + 1):
for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(prompt.seed, int):
prompt.seed += 100_000_000 + attempt
result = _generate_single_image(
@ -238,7 +238,7 @@ def _generate_single_image(
latent_channels = 4
downsampling_factor = 8
batch_size = 1
global _most_recent_result # noqa
global _most_recent_result
# handle prompt pulling in previous values
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
# _, img_type = prompt.init_image.strip("*").split(".")
@ -457,16 +457,17 @@ def _generate_single_image(
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if control_input.mode != "inpaint":
if control_image_t.min() < 0 or control_image_t.max() > 1:
raise RuntimeError(
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
)
if (
control_input.mode != "inpaint"
and control_image_t.min() < 0
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise RuntimeError(msg)
if control_image_t.max() == control_image_t.min():
raise RuntimeError(
f"No control signal found in control image {control_input.mode}."
)
msg = f"No control signal found in control image {control_input.mode}."
raise RuntimeError(msg)
c_cat.append(control_image_t)
control_strengths.append(control_input.strength)
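
One thing worth spelling out about the flattened condition above: Python binds "and" more tightly than "or", so the new expression groups as (mode != "inpaint" and min < 0) or (max > 1), which is not identical to the original nested check mode != "inpaint" and (min < 0 or max > 1). A tiny sketch with stand-in booleans:

is_not_inpaint = False   # i.e. control_input.mode == "inpaint"
below_zero = False       # control_image_t.min() < 0
above_one = True         # control_image_t.max() > 1

nested = is_not_inpaint and (below_zero or above_one)   # original nested ifs -> False
flat = is_not_inpaint and below_zero or above_one       # flattened form      -> True
print(nested, flat)  # False True
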
@ -517,7 +518,7 @@ def _generate_single_image(
if (
prompt.allow_compose_phase
and not is_controlnet_model
and not model.cond_stage_key == "edit"
and model.cond_stage_key != "edit"
):
if prompt.init_image:
comp_image = _generate_composition_image(

@ -4,6 +4,7 @@ import logging
import shlex
import traceback
from functools import update_wrapper
from typing import ClassVar
import click
from click_help_colors import HelpColorsCommand, HelpColorsMixin
@ -43,27 +44,23 @@ def mod_get_invoke(command):
# and that's not ideal when running in a shell.
pass
except Exception as e: # noqa
traceback.print_exception(e) # noqa
traceback.print_exception(e)
# logger.warning(traceback.format_exc())
# Always return False so the shell doesn't exit
return False
invoke_ = update_wrapper(invoke_, command.callback)
invoke_.__name__ = "do_%s" % command.name # noqa
invoke_.__name__ = "do_%s" % command.name
return invoke_
class ModClickShell(ClickShell):
def add_command(self, cmd, name):
# Use the MethodType to add these as bound methods to our current instance
setattr(
self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self) # noqa
)
setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self)) # noqa
setattr(
self, "complete_%s" % name, get_method_type(get_complete(cmd), self) # noqa
)
setattr(self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self))
setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self))
setattr(self, "complete_%s" % name, get_method_type(get_complete(cmd), self))
class ModShell(Shell):
@ -85,7 +82,7 @@ class ColorShell(HelpColorsMixin, ModShell):
class ImagineColorsCommand(HelpColorsCommand):
_option_order = []
_option_order: ClassVar = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

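Annotating the mutable class attribute with ClassVar (ruff rule RUF012) records that _option_order belongs to the class rather than being a per-instance field. A minimal sketch of what that sharing means:

from typing import ClassVar

class OptionRecorder:
    seen: ClassVar[list] = []   # one list shared by every instance of the class

    def record(self, name):
        self.seen.append(name)

a, b = OptionRecorder(), OptionRecorder()
a.record("width")
print(b.seen)   # ['width'] -- both instances see the same list
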
@ -43,7 +43,7 @@ remove_option(edit_options, "allow_compose_phase")
)
@add_options(edit_options)
@click.pass_context
def edit_cmd( # noqa
def edit_cmd(
ctx,
image_paths,
image_strength,
@ -77,7 +77,7 @@ def edit_cmd( # noqa
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
version,
make_gif,
make_compare_gif,
arg_schedules,
@ -130,7 +130,7 @@ def edit_cmd( # noqa
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
version,
make_gif,
make_compare_gif,
arg_schedules,

@ -90,7 +90,7 @@ def imagine_cmd(
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
version,
make_gif,
make_compare_gif,
arg_schedules,
@ -110,7 +110,7 @@ def imagine_cmd(
# hacky method of getting order of control images (mixing raw and normal images)
control_images = [
(o, path)
for o, path in ImagineColorsCommand._option_order # noqa
for o, path in ImagineColorsCommand._option_order
if o.name in ("control_image", "control_image_raw")
]
control_inputs = []
@ -176,7 +176,7 @@ def imagine_cmd(
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
version,
make_gif,
make_compare_gif,
arg_schedules,
@ -187,4 +187,4 @@ def imagine_cmd(
if __name__ == "__main__":
imagine_cmd() # noqa
imagine_cmd()

@ -92,4 +92,4 @@ def model_list_cmd():
if __name__ == "__main__":
aimg() # noqa
aimg()

@ -43,7 +43,7 @@ def _imagine_cmd(
model_weights_path,
model_config_path,
prompt_library_path,
version=False, # noqa
version=False,
make_gif=False,
make_compare_gif=False,
arg_schedules=None,
@ -78,10 +78,7 @@ def _imagine_cmd(
configure_logging(log_level)
if isinstance(init_image, str):
init_images = [init_image]
else:
init_images = init_image
init_images = [init_image] if isinstance(init_image, str) else init_image
from imaginairy.utils import glob_expand_paths
@ -89,9 +86,8 @@ def _imagine_cmd(
init_images = glob_expand_paths(init_images)
if len(init_images) < num_prexpaned_init_images:
raise ValueError(
f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
)
msg = f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
raise ValueError(msg)
total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
logger.info(
@ -227,7 +223,8 @@ def replace_option(options, option_name, new_option):
if option.name == option_name:
options[i] = new_option
return
raise ValueError(f"Option {option_name} not found")
msg = f"Option {option_name} not found"
raise ValueError(msg)
def remove_option(options, option_name):
@ -242,7 +239,8 @@ def remove_option(options, option_name):
if option.name == option_name:
del options[i]
return
raise ValueError(f"Option {option_name} not found")
msg = f"Option {option_name} not found"
raise ValueError(msg)
common_options = [

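The msg-then-raise pattern used throughout this commit comes from ruff's EM101/EM102 rules: binding the (often long) f-string to a variable keeps the raise line short, so tracebacks do not repeat the whole literal. A small self-contained sketch of the pattern (get_phrase_file is an illustrative name, not a function in the repo):

def get_phrase_file(lookup: dict, category: str) -> str:
    try:
        return lookup[category]
    except KeyError as e:
        msg = f"'{category}' is not a valid prompt expansion category."
        raise LookupError(msg) from e

# get_phrase_file({}, "nature")  # LookupError, with a short `raise LookupError(msg)` line in the traceback
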
@ -36,7 +36,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None):
steps=30,
prompt_strength=12,
)
result = list(imagine(prompt))[0]
result = next(iter(imagine(prompt)))
colorized_img = replace_color(img, result.images["generated"])
# allows the algorithm some leeway for the overall brightness of the image

@ -18,6 +18,7 @@ Examples:
"""
import operator
from abc import ABC
from typing import ClassVar
import pyparsing as pp
import torch
@ -57,7 +58,7 @@ class SimpleMask(Mask):
class ModifiedMask(Mask):
ops = {
ops: ClassVar = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
@ -80,7 +81,7 @@ class ModifiedMask(Mask):
return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])
def __repr__(self):
return f"{repr(self.mask)}{self.modifier}"
return f"{self.mask!r}{self.modifier}"
def gather_text_descriptions(self):
return self.mask.gather_text_descriptions()
@ -141,7 +142,8 @@ class NestedMask(Mask):
elif self.op == "NOT":
mask = 1 - mask
else:
raise ValueError(f"Invalid operand {self.op}")
msg = f"Invalid operand {self.op}"
raise ValueError(msg)
return torch.clamp(mask, 0, 1)

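The ModifiedMask hunk above uses an f-string conversion flag, f"{self.mask!r}", instead of calling repr() inside the braces, and the {e!s} later in the commit does the same for str(); the rendered text is identical. Sketch:

mask, modifier = "foreground", "*0.5"

old_style = f"{repr(mask)}{modifier}"
new_style = f"{mask!r}{modifier}"
assert old_style == new_style == "'foreground'*0.5"
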
@ -14,9 +14,9 @@ from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
@lru_cache()
@lru_cache
def clip_mask_model():
from imaginairy.paths import PKG_ROOT # noqa
from imaginairy.paths import PKG_ROOT
model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64, complex_trans_conv=True)
model.eval()
@ -36,7 +36,7 @@ def get_img_mask(
mask_description_statement: str,
threshold: Optional[float] = None,
):
from imaginairy.enhancers.bool_masker import MASK_PROMPT # noqa
from imaginairy.enhancers.bool_masker import MASK_PROMPT
parsed = MASK_PROMPT.parseString(mask_description_statement)
parsed_mask = parsed[0][0]

@ -17,9 +17,9 @@ if "mps" in device:
BLIP_EVAL_SIZE = 384
@lru_cache()
@lru_cache
def blip_model():
from imaginairy.paths import PKG_ROOT # noqa
from imaginairy.paths import PKG_ROOT
config_path = os.path.join(
PKG_ROOT, "vendored", "blip", "configs", "med_config.json"
@ -28,7 +28,7 @@ def blip_model():
model = BLIP_Decoder(image_size=BLIP_EVAL_SIZE, vit="base", med_config=config_path)
cached_url_path = get_cached_url_path(url)
model, msg = load_checkpoint(model, cached_url_path) # noqa
model, msg = load_checkpoint(model, cached_url_path)
model.eval()
model = model.to(device)
return model

@ -10,7 +10,7 @@ from imaginairy.vendored import clip
device = "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache()
@lru_cache
def get_model():
model_name = "ViT-L/14"
model, preprocess = clip.load(model_name, device=device)

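Since Python 3.8, functools.lru_cache can be applied directly as a decorator, so @lru_cache() and @lru_cache are equivalent (both default to maxsize=128); that is all these hunks change (ruff/pyupgrade rule UP011). Sketch:

from functools import lru_cache

@lru_cache
def get_model(name: str) -> str:
    print(f"loading {name} ...")   # runs only once per distinct name
    return f"<model {name}>"

get_model("ViT-L/14")
get_model("ViT-L/14")              # second call is served from the cache
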
@ -17,7 +17,7 @@ face_restore_device = torch.device("cuda" if torch.cuda.is_available() else "cpu
half_mode = face_restore_device == "cuda"
@lru_cache()
@lru_cache
def codeformer_model():
model = CodeFormer(
dim_embd=512,
@ -36,7 +36,7 @@ def codeformer_model():
return model
@lru_cache()
@lru_cache
def face_restore_helper():
"""
Provide a singleton of FaceRestoreHelper.
@ -85,11 +85,11 @@ def enhance_faces(img, fidelity=0):
try:
with torch.no_grad():
output = net(cropped_face_t, w=fidelity, adain=True)[0] # noqa
output = net(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error: # noqa
except Exception as error:
logger.exception(f"\tFailed inference for CodeFormer: {error}")
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

@ -15,10 +15,10 @@ formatter = Formatter()
PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+")
@lru_cache()
@lru_cache
def prompt_library_filepaths(prompt_library_paths=None):
"""Return all available category/filepath pairs."""
prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths
prompt_library_paths = prompt_library_paths if prompt_library_paths else []
combined_prompt_library_filepaths = {}
for prompt_path in DEFAULT_PROMPT_LIBRARY_PATHS + list(prompt_library_paths):
library_prompts = prompt_library_filepath(prompt_path)
@ -27,7 +27,7 @@ def prompt_library_filepaths(prompt_library_paths=None):
return combined_prompt_library_filepaths
@lru_cache()
@lru_cache
def category_list(prompt_library_paths=None):
"""Return the names of available phrase-lists."""
categories = list(prompt_library_filepaths(prompt_library_paths).keys())
@ -35,7 +35,7 @@ def category_list(prompt_library_paths=None):
return categories
@lru_cache()
@lru_cache
def prompt_library_filepath(library_path):
lookup = {}
@ -55,9 +55,8 @@ def get_phrases(category_name, prompt_library_paths=None):
try:
filepath = lookup[category_name]
except KeyError as e:
raise LookupError(
f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file."
) from e
msg = f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file."
raise LookupError(msg) from e
_open = open
if filepath.endswith(".gz"):
_open = gzip.open
@ -83,13 +82,12 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None):
"""
prompt_parts = list(formatter.parse(prompt_text))
field_names = []
for literal_text, field_name, format_spec, conversion in prompt_parts: # noqa
for literal_text, field_name, format_spec, conversion in prompt_parts:
if field_name:
field_name = field_name.lower()
if not PROMPT_EXPANSION_PATTERN.match(field_name):
raise ValueError(
"Invalid prompt expansion. Only a-z0-9_|- characters permitted. "
)
msg = "Invalid prompt expansion. Only a-z0-9_|- characters permitted. "
raise ValueError(msg)
field_names.append(field_name)
phrases = []
@ -120,9 +118,7 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None):
yield output_prompt
def get_random_non_repeating_combination( # noqa
n=1, *sequences, allow_oversampling=True
):
def get_random_non_repeating_combination(n=1, *sequences, allow_oversampling=True):
"""
Efficiently return a non-repeating random sample of the product sequences.

@ -187,7 +187,7 @@ class CLIPEmbedder(nn.Module):
)
@lru_cache()
@lru_cache
def clip_up_models():
with platform_appropriate_autocast():
tok_up = CLIPTokenizerTransform()
@ -290,7 +290,8 @@ def upscale_latent(
eta=eta,
**sampler_opts,
)
raise ValueError(f"Unknown sampler {sampler}")
msg = f"Unknown sampler {sampler}"
raise ValueError(msg)
for _ in range((num_samples - 1) // batch_size + 1):
if noise_aug_type == "gaussian":
@ -300,7 +301,7 @@ def upscale_latent(
elif noise_aug_type == "fake":
latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5
extra_args = {
"low_res": latent_noised, # noqa
"low_res": latent_noised,
"low_res_sigma": low_res_sigma,
"c": c,
}

@ -63,7 +63,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage")
@validator("seed")
def validate_seed(cls, v): # noqa
def validate_seed(cls, v):
if v == 0:
return None
return v
@ -74,10 +74,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
from PIL import Image
if self.prompts:
positive_prompt = self.prompts[0].text
else:
positive_prompt = None
positive_prompt = self.prompts[0].text if self.prompts else None
if self.prompts and len(self.prompts) > 1:
negative_prompt = self.prompts[1].text if len(self.prompts) > 1 else None
else:

@ -79,7 +79,7 @@ def _create_depth_map_raw(img):
align_corners=False,
)
depth_pt = model(img)[0] # noqa
depth_pt = model(img)[0]
return depth_pt
@ -209,7 +209,10 @@ def inpaint_prep(mask_image_t, target_image_t):
def to_grayscale(img):
# The dimensions of input should be (batch_size, channels, height, width)
assert img.dim() == 4 and img.size(1) == 3
if img.dim() != 4:
raise ValueError("Input should be a 4d tensor")
if img.size(1) != 3:
raise ValueError("Input should have 3 channels")
# Apply the formula to convert to grayscale.
gray = (

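The to_grayscale hunk swaps an assert for explicit ValueError checks; the conversion itself, truncated above, is a luminance-weighted sum over the RGB channels. A standalone sketch assuming the common ITU-R BT.601 weights (the project's actual coefficients may differ):

import torch

def to_grayscale(img: torch.Tensor) -> torch.Tensor:
    # img has shape (batch_size, 3, height, width)
    if img.dim() != 4:
        raise ValueError("Input should be a 4d tensor")
    if img.size(1) != 3:
        raise ValueError("Input should have 3 channels")
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    return 0.299 * r + 0.587 * g + 0.114 * b   # shape (batch_size, 1, height, width)
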
@ -3,7 +3,7 @@ from collections import OrderedDict
from functools import lru_cache
import cv2
import matplotlib
import matplotlib as mpl
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
@ -40,7 +40,7 @@ def pad_right_down_corner(img, stride, padValue):
def transfer(model, model_weights):
# transfer caffe model to pytorch which will match the layer name
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
for weights_name in model.state_dict():
transfered_model_weights[weights_name] = model_weights[
".".join(weights_name.split(".")[1:])
]
@ -93,14 +93,14 @@ def draw_bodypose(canvas, candidate, subset):
[255, 0, 85],
]
for i in range(18):
for n in range(len(subset)): # noqa
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17):
for n in range(len(subset)): # noqa
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
@ -155,8 +155,7 @@ def draw_handpose(canvas, all_hand_peaks, show_number=False):
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
* 255,
mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
thickness=2,
)
@ -368,7 +367,7 @@ class bodypose_model(nn.Module):
]
)
for k in blocks.keys():
for k in blocks:
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
@ -473,7 +472,7 @@ class handpose_model(nn.Module):
]
)
for k in blocks.keys():
for k in blocks:
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
@ -625,7 +624,7 @@ def create_body_pose(original_img_t):
peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peaks_with_score = [(*x, map_ori[x[1], x[0]]) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))
@ -751,7 +750,7 @@ def create_body_pose(original_img_t):
connection_candidate, key=lambda x: x[2], reverse=True
)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)): # noqa
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack(

@ -101,9 +101,10 @@ def torch_img_to_pillow_img(img_t: torch.Tensor):
elif img_t.shape[1] == 3:
colorspace = "RGB"
else:
raise ValueError(
msg = (
f"Unsupported colorspace. {img_t.shape[1]} channels in {img_t.shape} shape"
)
raise ValueError(msg)
img_t = rearrange(img_t, "b c h w -> b h w c")
img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0)
img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0]
@ -113,7 +114,7 @@ def torch_img_to_pillow_img(img_t: torch.Tensor):
def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image:
from imaginairy.model_manager import get_current_diffusion_model # noqa
from imaginairy.model_manager import get_current_diffusion_model
if len(latent.shape) == 3:
latent = latent.unsqueeze(0)

@ -94,13 +94,13 @@ class ImageLoggingContext:
self._prev_log_context = None
def __enter__(self):
global _CURRENT_LOGGING_CONTEXT # noqa
global _CURRENT_LOGGING_CONTEXT
self._prev_log_context = _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_LOGGING_CONTEXT # noqa
global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self._prev_log_context
def timing(self, description):
@ -120,21 +120,20 @@ class ImageLoggingContext:
)
def log_latents(self, latents, description):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
from imaginairy.img_utils import model_latents_to_pillow_imgs
if "predicted_latent" in description:
if self.progress_latent_callback is not None:
self.progress_latent_callback(latents)
if (
self.step_count - self.last_progress_img_step
) > self.progress_img_interval_steps:
if (
time.perf_counter() - self.last_progress_img_ts
> self.progress_img_interval_min_s
):
self.log_progress_latent(latents)
self.last_progress_img_step = self.step_count
self.last_progress_img_ts = time.perf_counter()
) > self.progress_img_interval_steps and (
time.perf_counter() - self.last_progress_img_ts
> self.progress_img_interval_min_s
):
self.log_progress_latent(latents)
self.last_progress_img_step = self.step_count
self.last_progress_img_ts = time.perf_counter()
if not self.debug_img_callback:
return
@ -168,7 +167,7 @@ class ImageLoggingContext:
)
def log_progress_latent(self, latent):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
from imaginairy.img_utils import model_latents_to_pillow_imgs
if not self.progress_img_callback:
return
@ -280,7 +279,7 @@ def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger as pytorch_logger
try:
from pytorch_lightning.utilities.seed import log # noqa
from pytorch_lightning.utilities.seed import log
log.setLevel(logging.NOTSET)
log.handlers = []

@ -24,9 +24,8 @@ class LambdaWarmUpCosineScheduler:
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
@ -66,7 +65,7 @@ class LambdaWarmUpCosineScheduler2:
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.cum_cycles = np.cumsum([0, *list(self.cycle_lengths)])
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
@ -81,12 +80,11 @@ class LambdaWarmUpCosineScheduler2:
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
@ -112,12 +110,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[

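The scheduler hunks above collapse a nested if into a single condition joined with "and" (ruff rule SIM102), the same rewrite applied to log_latents earlier in the commit. Sketch:

verbosity_interval, n, last_lr = 10, 20, 0.5

# before
if verbosity_interval > 0:
    if n % verbosity_interval == 0:
        print(f"current step: {n}, recent lr-multiplier: {last_lr}")

# after
if verbosity_interval > 0 and n % verbosity_interval == 0:
    print(f"current step: {n}, recent lr-multiplier: {last_lr}")
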
@ -7,9 +7,11 @@ from functools import wraps
import requests
import torch
from huggingface_hub import HfFolder
from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import try_to_load_from_cache
from huggingface_hub import (
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
)
from omegaconf import OmegaConf
from safetensors.torch import load_file
@ -65,16 +67,19 @@ def load_state_dict(weights_location, half_mode=False, device=None):
f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.'
)
sys.exit(1)
raise e
raise
except RuntimeError as e:
if "PytorchStreamReader failed reading zip archive" in str(e):
if weights_location.startswith("http"):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
ckpt_path = get_cached_url_path(weights_location, category="weights")
state_dict = load_tensors(ckpt_path, map_location="cpu")
err_str = str(e)
if (
"PytorchStreamReader failed reading zip archive" in err_str
and weights_location.startswith("http")
):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
ckpt_path = get_cached_url_path(weights_location, category="weights")
state_dict = load_tensors(ckpt_path, map_location="cpu")
if state_dict is None:
raise e
raise
state_dict = state_dict.get("state_dict", state_dict)
@ -166,7 +171,7 @@ def get_diffusion_model(
except HuggingFaceAuthorizationError as e:
if for_inpainting:
logger.warning(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}"
f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
)
return _get_diffusion_model(
iconfig.DEFAULT_MODEL,
@ -176,7 +181,7 @@ def get_diffusion_model(
for_training=for_training,
control_weights_locations=control_weights_locations,
)
raise e
raise
def _get_diffusion_model(
@ -192,7 +197,7 @@ def _get_diffusion_model(
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
global MOST_RECENTLY_LOADED_MODEL # noqa
global MOST_RECENTLY_LOADED_MODEL
(
model_config,
@ -293,9 +298,8 @@ def resolve_model_paths(
if for_training:
weights_path = model_metadata_w.weights_url_full
if weights_path is None:
raise ValueError(
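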
"No full training weights configured for this model. Edit the code or subimt a github issue."
)
msg = "No full training weights configured for this model. Edit the code or subimt a github issue."
raise ValueError(msg)
else:
weights_path = model_metadata_w.weights_url
@ -306,9 +310,8 @@ def resolve_model_paths(
config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path
if control_net_metadatas:
if "stable-diffusion-v1" not in config_path:
raise ValueError(
"Control net is only supported for stable diffusion v1. Please use a different model."
)
msg = "Control net is only supported for stable diffusion v1. Please use a different model."
raise ValueError(msg)
control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas]
config_path = control_net_metadatas[0].config_path
model_metadata = model_metadata_w or model_metadata_c
@ -374,7 +377,7 @@ def get_cached_url_path(url, category=None):
os.rename(old_dest_path, dest_path)
return dest_path
r = requests.get(url) # noqa
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
@ -390,12 +393,8 @@ def check_huggingface_url_authorized(url):
headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
if response.status_code == 401:
raise HuggingFaceAuthorizationError(
"Unauthorized access to HuggingFace model. This model requires a huggingface token. "
"Please login to HuggingFace "
"or set HUGGING_FACE_HUB_TOKEN to your User Access Token. "
"See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
)
msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
raise HuggingFaceAuthorizationError(msg)
return None
@ -413,7 +412,7 @@ def hf_hub_download(*args, **kwargs):
if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
raise e
raise
def huggingface_cached_path(url):

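Inside an except block, a bare raise re-raises the exception currently being handled with its original traceback, so naming it (raise e) is redundant; that is what the model_manager hunks above change (ruff rule TRY201). A minimal sketch with stand-in functions:

def load_tensors(path: str):
    raise FileNotFoundError(path)      # stand-in for the real checkpoint loader

def load_state_dict(path: str):
    try:
        return load_tensors(path)
    except FileNotFoundError:
        print(f"could not read weights at {path}")
        raise                          # same exception, original traceback

# load_state_dict("missing.ckpt")     # FileNotFoundError propagates to the caller
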
@ -14,8 +14,8 @@ XFORMERS_IS_AVAILABLE = False
try:
if get_device() == "cuda":
import xformers # noqa
import xformers.ops # noqa
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except ImportError:
@ -79,7 +79,7 @@ class LinearAttention(nn.Module):
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape # noqa
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
@ -120,7 +120,7 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b, c, h, w = q.shape # noqa
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
@ -183,7 +183,7 @@ class CrossAttention(nn.Module):
# if mask is None and _global_mask_hack is not None:
# mask = _global_mask_hack.to(torch.bool)
if get_device() == "cuda" or "mps" in get_device():
if get_device() == "cuda" or "mps" in get_device(): # noqa
if not XFORMERS_IS_AVAILABLE and ALLOW_SPLITMEM:
return self.forward_splitmem(x, context=context, mask=mask)
@ -222,7 +222,7 @@ class CrossAttention(nn.Module):
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
def forward_splitmem(self, x, context=None, mask=None): # noqa
def forward_splitmem(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@ -262,10 +262,8 @@ class CrossAttention(nn.Module):
max_res = (
math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
)
raise RuntimeError(
f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
f"Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
)
msg = f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
raise RuntimeError(msg)
slice_size = (
q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
)
@ -474,7 +472,7 @@ class SpatialTransformer(nn.Module):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape # noqa
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if self.use_linear:

@ -290,10 +290,7 @@ class AutoencoderKL(pl.LightningModule):
def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
z = posterior.sample() if sample_posterior else posterior.mode()
dec = self.decode(z)
return dec, posterior
@ -484,7 +481,7 @@ class AutoencoderKL(pl.LightningModule):
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape # noqa
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1

@ -19,12 +19,12 @@ from imaginairy.modules.diffusion.util import (
class ControlledUnetModel(UNetModel):
def forward( # noqa
def forward(
self,
x,
timesteps=None,
context=None,
control=None, # noqa
control=None,
only_mid_control=False,
**kwargs,
):
@ -129,10 +129,8 @@ class ControlNet(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult"
raise ValueError(msg)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
@ -140,10 +138,8 @@ class ControlNet(nn.Module):
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
(
self.num_res_blocks[i] >= num_attention_blocks[i]
for i in range(len(num_attention_blocks))
)
self.num_res_blocks[i] >= num_attention_blocks[i]
for i in range(len(num_attention_blocks))
)
print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
@ -425,10 +421,10 @@ class ControlLDM(LatentDiffusion):
if is_diffusing:
self.model = self.model.cuda()
self.control_models = [cm.cuda() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cpu() # noqa
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_models = [cm.cpu() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cuda() # noqa
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()

@ -102,9 +102,7 @@ class FrozenClipImageEmbedder(nn.Module):
antialias=False,
):
super().__init__()
self.model, preprocess = clip.load( # noqa
name=model_name, device=device, jit=jit
)
self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
self.antialias = antialias

@ -9,6 +9,7 @@ import itertools
import logging
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Optional
import numpy as np
import pytorch_lightning as pl
@ -371,7 +372,7 @@ class DDPM(pl.LightningModule):
# we only modify first two axes
assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
if new_shape != old_shape:
new_param = param.clone()
old_param = sd[name]
if len(new_shape) == 1:
@ -495,7 +496,7 @@ class DDPM(pl.LightningModule):
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(
reversed(range(0, self.num_timesteps)),
reversed(range(self.num_timesteps)),
desc="Sampling t",
total=self.num_timesteps,
):
@ -563,9 +564,8 @@ class DDPM(pl.LightningModule):
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(
f"Parameterization {self.parameterization} not yet supported"
)
msg = f"Parameterization {self.parameterization} not yet supported"
raise NotImplementedError(msg)
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
@ -706,7 +706,7 @@ class DDPM(pl.LightningModule):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
params = [*params, self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
@ -716,7 +716,7 @@ def _TileModeConv2DConvForward(
):
if self.padding_modeX == self.padding_modeY:
self.padding_mode = self.padding_modeX
return self._orig_conv_forward(input, weight, bias) # noqa
return self._orig_conv_forward(input, weight, bias)
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
del input
@ -790,9 +790,7 @@ class LatentDiffusion(DDPM):
if isinstance(m, nn.Conv2d):
m._initial_padding_mode = m.padding_mode
m._orig_conv_forward = m._conv_forward
m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa
m, nn.Conv2d
)
m._conv_forward = _TileModeConv2DConvForward.__get__(m, nn.Conv2d)
self.tile_mode(tile_mode=False)
def tile_mode(self, tile_mode):
@ -807,16 +805,16 @@ class LatentDiffusion(DDPM):
if m.padding_modeY == m.padding_modeX:
m.padding_mode = m.padding_modeX
m.paddingX = (
m._reversed_padding_repeated_twice[0], # noqa
m._reversed_padding_repeated_twice[1], # noqa
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.paddingY = (
0,
0,
m._reversed_padding_repeated_twice[2], # noqa
m._reversed_padding_repeated_twice[3], # noqa
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
def make_cond_schedule(
@ -896,9 +894,8 @@ class LatentDiffusion(DDPM):
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
)
msg = f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
raise NotImplementedError(msg)
return self.scale_factor * z
def get_learned_conditioning(self, c):
@ -967,7 +964,7 @@ class LatentDiffusion(DDPM):
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape # noqa
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
@ -1167,7 +1164,7 @@ class LatentDiffusion(DDPM):
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
h, w = x_noisy.shape[-2:] # noqa
h, w = x_noisy.shape[-2:]
fold, unfold, normalization, weighting = self.get_fold_unfold(
x_noisy, ks, stride
@ -1239,9 +1236,7 @@ class LatentDiffusion(DDPM):
# tokenize crop coordinates for the bounding boxes of the respective patches
patch_limits_tknzd = [
torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[ # noqa
None
].to( # noqa
torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(
self.device
)
for bbox in patch_limits
@ -1292,7 +1287,7 @@ class LatentDiffusion(DDPM):
return x_recon
def p_losses(self, x_start, cond, t, noise=None): # noqa
def p_losses(self, x_start, cond, t, noise=None):
noise = noise if noise is not None else torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
@ -1374,7 +1369,7 @@ class LatentDiffusion(DDPM):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample( # noqa
def p_sample(
self,
x,
c,
@ -1609,7 +1604,7 @@ class LatentDiffusion(DDPM):
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
b, h, w = z.shape[0], z.shape[2], z.shape[3] # noqa
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
@ -1674,9 +1669,8 @@ class LatentDiffusion(DDPM):
logger.info("Training the full unet")
params = list(self.model.parameters())
else:
raise ValueError(
f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
)
msg = f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
raise ValueError(msg)
if self.cond_stage_trainable:
logger.info(
@ -1706,7 +1700,7 @@ class LatentDiffusion(DDPM):
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) # noqa
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@ -1719,17 +1713,19 @@ class DiffusionWrapper(pl.LightningModule):
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm"]
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
def forward(
self, x, t, c_concat: Optional[list] = None, c_crossattn: Optional[list] = None
):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == "concat":
xc = torch.cat([x] + c_concat, dim=1)
xc = torch.cat([x, *c_concat], dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == "crossattn":
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == "hybrid":
xc = torch.cat([x] + c_concat, dim=1)
xc = torch.cat([x, *c_concat], dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == "adm":
@ -1818,7 +1814,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
# print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def log_images( # noqa
def log_images(
self,
batch,
N=8,
@ -1866,7 +1862,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage(
c_cat[:, self.c_concat_log_start : self.c_concat_log_end] # noqa
c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
)
if plot_diffusion_rows:
@ -1929,11 +1925,11 @@ class LatentFinetuneDiffusion(LatentDiffusion):
class LatentInpaintDiffusion(LatentDiffusion):
def __init__( # noqa
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
finetune_keys=None, # noqa
finetune_keys=None,
*args,
**kwargs,
):

@ -16,8 +16,8 @@ XFORMERS_IS_AVAILABLE = False
try:
if get_device() == "cuda":
import xformers # noqa
import xformers.ops # noqa
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except ImportError:
@ -415,7 +415,7 @@ class Model(nn.Module):
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
in_ch_mult = (1, *tuple(ch_mult))
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
@ -581,7 +581,7 @@ class Encoder(nn.Module):
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
in_ch_mult = (1, *tuple(ch_mult))
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
@ -853,10 +853,7 @@ class SimpleDecoder(nn.Module):
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
x = layer(x, None) if i in [1, 2, 3] else layer(x)
h = self.norm_out(x)
h = silu(h)

@ -1,5 +1,6 @@
import math
from abc import abstractmethod
from typing import Optional
import numpy as np
import torch as th
@ -38,7 +39,7 @@ class AttentionPool2d(nn.Module):
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
output_dim: Optional[int] = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
@ -519,10 +520,8 @@ class UNetModel(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult"
raise ValueError(msg)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not

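output_dim: int = None declares an int but defaults to None; ruff's RUF013 wants the implicit optional spelled out, hence Optional[int] in the AttentionPool2d hunk above. A toy sketch of the same idea (attention_pool is an illustrative stand-in):

from typing import Optional

def attention_pool(embed_dim: int, output_dim: Optional[int] = None) -> int:
    # fall back to the embedding width when no projection size is given
    return output_dim if output_dim is not None else embed_dim

assert attention_pool(512) == 512
assert attention_pool(512, 256) == 256
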
@ -52,7 +52,8 @@ def make_beta_schedule(
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
msg = f"schedule '{schedule}' unknown."
raise ValueError(msg)
return betas.numpy()
@ -80,9 +81,8 @@ def make_ddim_timesteps(
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
msg = f'There is no ddim discretization method called "{ddim_discr_method}"'
raise NotImplementedError(msg)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -93,7 +93,7 @@ def make_ddim_timesteps(
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray([alphacums[0], *alphacums[ddim_timesteps[:-1]].tolist()])
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
@ -151,7 +151,7 @@ def checkpoint(func, inputs, params, flag):
return func(*inputs)
class CheckpointFunction(torch.autograd.Function): # noqa
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
@ -180,7 +180,7 @@ class CheckpointFunction(torch.autograd.Function): # noqa
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
return (None, None, *input_grads)
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
@ -246,7 +246,7 @@ def normalization(channels):
class GroupNorm32(nn.GroupNorm):
def forward(self, x): # noqa
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
@ -260,7 +260,8 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
if dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
msg = f"unsupported dimensions: {dims}"
raise ValueError(msg)
def linear(*args, **kwargs):
@ -278,7 +279,8 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
if dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
msg = f"unsupported dimensions: {dims}"
raise ValueError(msg)
class HybridConditioner(nn.Module):

@ -43,7 +43,7 @@ class ClassEmbedder(nn.Module):
return uc
def disabled_train(self, mode=True): # noqa
def disabled_train(self, mode=True):
"""
For disabling train/eval mode.

@ -61,9 +61,8 @@ def load_midas_transform(model_type="dpt_hybrid"):
)
else:
assert (
False
), f"model_type '{model_type}' not implemented, use: --model_type large"
msg = f"model_type '{model_type}' not implemented, use: --model_type large"
raise NotImplementedError(msg)
transform = Compose(
[
@ -133,8 +132,8 @@ def load_model(model_type):
)
else:
print(f"model_type '{model_type}' not implemented, use: --model_type large")
assert False
msg = f"model_type '{model_type}' not implemented, use: --model_type large"
raise NotImplementedError(msg)
transform = Compose(
[
@ -155,13 +154,13 @@ def load_model(model_type):
return model.eval(), transform
@lru_cache()
@lru_cache
def midas_device():
# mps returns incorrect results ~50% of the time
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
@lru_cache()
@lru_cache
def load_midas(model_type="dpt_hybrid"):
model = MiDaSInference(model_type)
model.to(midas_device())

@ -4,7 +4,7 @@ from imaginairy import config
from imaginairy.model_manager import get_cached_url_path
class BaseModel(torch.nn.Module): # noqa
class BaseModel(torch.nn.Module):
def load(self, path):
"""
Load model from file.

@ -56,8 +56,8 @@ def _make_encoder(
[32, 48, 136, 384], features, groups=groups, expand=expand
) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
msg = f"Backbone '{backbone}' not implemented"
raise NotImplementedError(msg)
return pretrained, scratch

@ -135,9 +135,8 @@ class Resize:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
msg = f"resize_method {self.__resize_method} not implemented"
raise ValueError(msg)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
@ -157,7 +156,8 @@ class Resize:
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
msg = f"resize_method {self.__resize_method} not implemented"
raise ValueError(msg)
return (new_width, new_height)

@ -22,10 +22,7 @@ class AddReadout(nn.Module):
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
@ -118,7 +115,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed( # noqa
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
@ -174,9 +171,8 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1):
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
msg = "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
raise ValueError(msg)
return readout_oper
@ -288,7 +284,7 @@ def _make_vit_b16_backbone(
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType( # noqa
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
@ -469,7 +465,7 @@ def _make_vit_b_rn50_backbone(
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType( # noqa
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)

@ -94,9 +94,8 @@ def write_pfm(path, image, scale=1):
): # greyscale
color = False
else:
raise ValueError(
"Image must have H x W x 3, H x W x 1 or H x W dimensions."
)
msg = "Image must have H x W x 3, H x W x 1 or H x W dimensions."
raise ValueError(msg)
file.write("PF\n" if color else b"Pf\n")
file.write(b"%d %d\n" % (image.shape[1], image.shape[0]))
@ -144,10 +143,7 @@ def resize_image(img):
height_orig = img.shape[0]
width_orig = img.shape[1]
if width_orig > height_orig:
scale = width_orig / 384
else:
scale = height_orig / 384
scale = width_orig / 384 if width_orig > height_orig else height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

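Several hunks (the init_images handling, AddReadout, and the resize helper above) replace a four-line if/else assignment with a conditional expression (ruff rule SIM108). Sketch:

width_orig, height_orig = 1024, 640

# before
if width_orig > height_orig:
    scale = width_orig / 384
else:
    scale = height_orig / 384

# after
scale = width_orig / 384 if width_orig > height_orig else height_orig / 384
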
@ -217,15 +217,18 @@ def outpaint_arg_str_parse(arg_str):
for arg in args:
match = arg_pattern.match(arg)
if not match:
raise ValueError(f"Invalid outpaint argument '{arg}'")
msg = f"Invalid outpaint argument '{arg}'"
raise ValueError(msg)
direction, amount = match.groups()
direction = direction.lower()
if len(direction) == 1:
if direction not in valid_direction_chars:
raise ValueError(f"Invalid outpaint direction '{direction}'")
msg = f"Invalid outpaint direction '{direction}'"
raise ValueError(msg)
direction = valid_direction_chars[direction]
elif direction not in valid_directions:
raise ValueError(f"Invalid outpaint direction '{direction}'")
msg = f"Invalid outpaint direction '{direction}'"
raise ValueError(msg)
kwargs[direction] = int(amount)
if "all" in kwargs:

@ -11,13 +11,13 @@ def parse_schedule_str(schedule_str):
pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]")
match = pattern.match(schedule_str)
if not match:
raise ValueError(f"Invalid kwarg schedule: {schedule_str}")
msg = f"Invalid kwarg schedule: {schedule_str}"
raise ValueError(msg)
arg_name = match.group(1).replace("-", "_")
if not hasattr(ImaginePrompt(), arg_name):
raise ValueError(
f"Invalid kwarg schedule. Not a valid argument name: {arg_name}"
)
msg = f"Invalid kwarg schedule. Not a valid argument name: {arg_name}"
raise ValueError(msg)
arg_values = match.group(2)
if ":" in arg_values:
@ -53,7 +53,7 @@ def prompt_mutator(prompt, schedules):
}
"""
schedule_length = len(list(schedules.values())[0])
schedule_length = len(next(iter(schedules.values())))
for i in range(schedule_length):
new_prompt = copy(prompt)
for attr_name, schedule in schedules.items():

@ -24,7 +24,8 @@ def square_roi_coordinate(roi, max_width, max_height, best_effort=False):
width = x2 - x1
height = y2 - y1
if not best_effort and width != height:
raise RuntimeError(f"ROI is not square: {width}x{height}")
msg = f"ROI is not square: {width}x{height}"
raise RuntimeError(msg)
return x1, y1, x2, y2
@ -96,8 +97,7 @@ def move_roi_into_bounds(roi, max_width, max_height, force=False):
if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height:
roi_width = x2 - x1
roi_height = y2 - y1
raise RoiNotInBoundsError(
f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
)
msg = f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
raise RoiNotInBoundsError(msg)
return x1, y1, x2, y2

@ -113,7 +113,7 @@ class EnhancedStableDiffusionSafetyChecker(
return safety_results
@lru_cache()
@lru_cache
def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker"
monkeypatch_safety_cosine_distance()
@ -124,7 +124,7 @@ def safety_models():
return safety_feature_extractor, safety_checker
@lru_cache()
@lru_cache
def monkeypatch_safety_cosine_distance():
orig_cosine_distance = safety_checker_mod.cosine_distance

@ -51,17 +51,19 @@ class InvalidUrlError(ValueError):
class LazyLoadingImage:
"""Image file encoded as base64 string."""
def __init__(self, *, filepath=None, url=None, img: Image = None, b64: str = None):
def __init__(
self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None
):
if not filepath and not url and not img and not b64:
raise ValueError(
"You must specify a url or filepath or img or base64 string"
)
msg = "You must specify a url or filepath or img or base64 string"
raise ValueError(msg)
if sum([bool(filepath), bool(url), bool(img), bool(b64)]) > 1:
raise ValueError("You cannot multiple input methods")
# validate file exists
if filepath and not os.path.exists(filepath):
raise FileNotFoundError(f"File does not exist: {filepath}")
msg = f"File does not exist: {filepath}"
raise FileNotFoundError(msg)
# validate url is valid url
if url:
@ -73,7 +75,8 @@ class LazyLoadingImage:
except LocationParseError:
raise InvalidUrlError(f"Invalid url: {url}") # noqa
if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
raise InvalidUrlError(f"Invalid url: {url}")
msg = f"Invalid url: {url}"
raise InvalidUrlError(msg)
if b64:
img = self.load_image_from_base64(b64)
@ -145,16 +148,14 @@ class LazyLoadingImage:
raise ValueError(msg) # noqa
if isinstance(value, dict):
return cls(**value)
raise ValueError(
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
)
msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
raise ValueError(msg)
def handle_b64(value: Any) -> "LazyLoadingImage":
if isinstance(value, str):
return cls(b64=value)
raise ValueError(
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
)
msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
raise ValueError(msg)
return core_schema.json_or_python_schema(
json_schema=core_schema.chain_schema(
@ -349,15 +350,13 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return ""
if not isinstance(v, str):
raise ValueError(
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
)
msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
raise ValueError(msg) # noqa
v = v.lower()
if v not in valid_tile_modes:
raise ValueError(
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
)
msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
raise ValueError(msg)
return v
@field_validator("outpaint", mode="after")
@ -375,7 +374,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return v
if not isinstance(v, Tensor):
raise ValueError("conditioning must be a torch.Tensor")
raise ValueError("conditioning must be a torch.Tensor") # noqa
return v
# @field_validator("init_image", "mask_image", mode="after")
@ -412,13 +411,15 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@field_validator("mask_image")
def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
if v is not None and info.data.get("mask_prompt") is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
msg = "You can only set one of `mask_image` and `mask_prompt`"
raise ValueError(msg)
return v
@field_validator("mask_prompt", "mask_image", mode="before")
def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
if info.data.get("init_image") is None and v:
raise ValueError("You must set `init_image` if you want to use a mask")
msg = "You must set `init_image` if you want to use a mask"
raise ValueError(msg)
return v
@field_validator("model", mode="before")
@ -455,9 +456,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
SamplerName.PLMS,
SamplerName.DDIM,
):
raise ValueError(
"PLMS and DDIM samplers are not supported for pix2pix edit model."
)
msg = "PLMS and DDIM samplers are not supported for pix2pix edit model."
raise ValueError(msg)
return v
@field_validator("steps")
@ -620,7 +620,7 @@ class ImagineResult:
self.is_nsfw = is_nsfw
self.safety_score = safety_score
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.created_at = datetime.now(tz=timezone.utc)
self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device())
@ -655,9 +655,8 @@ class ImagineResult:
def save(self, save_path, image_type="generated"):
img = self.images.get(image_type, None)
if img is None:
raise ValueError(
f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
)
msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
raise ValueError(msg)
img.convert("RGB").save(save_path, exif=self._exif())

@ -192,7 +192,7 @@ def surprise_me_prompts(
width=width,
height=height,
seed=seed,
**kwargs, # noqa
**kwargs,
)
)
else:
@ -206,7 +206,7 @@ def surprise_me_prompts(
width=width,
height=height,
seed=seed,
**kwargs, # noqa
**kwargs,
)
)

@ -20,6 +20,8 @@ except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work
DDPStrategy = None
import contextlib
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
@ -220,10 +222,8 @@ class SetupCallback(Callback):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback):
@ -342,11 +342,12 @@ class ImageLogger(Callback):
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
if (
hasattr(pl_module, "calibrate_grad_norm")
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
and batch_idx > 0
):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
@ -356,9 +357,9 @@ class CUDACallback(Callback):
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time() # noqa
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module): # noqa
def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
@ -394,19 +395,20 @@ def train_diffusion_model(
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
"""
if DDPStrategy is None:
raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model")
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model( # noqa
model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
@ -501,9 +503,7 @@ def train_diffusion_model(
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # noqa
],
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
gpus=1,
default_root_dir=".",
)

@ -145,7 +145,7 @@ def create_class_images(class_description, output_folder, num_images=200):
while existing_image_count < num_images:
prompt = ImaginePrompt(class_description, steps=20)
result = list(imagine([prompt]))[0]
result = next(iter(imagine([prompt])))
if result.is_nsfw:
continue
dest = os.path.join(

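Several hunks replace list(imagine([prompt]))[0] with next(iter(imagine([prompt]))). Assuming imagine() yields its results lazily, the new form pulls only the first result instead of exhausting the whole generator before indexing. A toy generator shows the behaviour (results() is a stand-in, not part of the library):

def results():
    for i in range(3):
        print(f"producing {i}")
        yield i


first = next(iter(results()))  # prints only "producing 0"
# list(results())[0] would print all three before returning the first item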
@ -29,7 +29,7 @@ def prune_model_data(data, only_keep_ema=True):
data.pop("optimizer_states", None)
if only_keep_ema:
state_dict = data["state_dict"]
model_keys = [k for k in state_dict.keys() if k.startswith("model.")]
model_keys = [k for k in state_dict if k.startswith("model.")]
for model_key in model_keys:
ema_key = "model_ema." + model_key[6:].replace(".", "")

@ -92,7 +92,8 @@ class SingleConceptDataset(Dataset):
try:
image = Image.open(img_path).convert("RGB")
except RuntimeError as e:
raise RuntimeError(f"Could not read image {img_path}") from e
msg = f"Could not read image {img_path}"
raise RuntimeError(msg) from e
image = self.image_transforms(image)
data = {"image": image, "txt": txt}
return data

@ -14,7 +14,7 @@ from torch.overrides import handle_torch_function, has_torch_function_variadic
logger = logging.getLogger(__name__)
@lru_cache()
@lru_cache
def get_device() -> str:
"""Return the best torch backend available."""
if torch.cuda.is_available():
@ -26,7 +26,7 @@ def get_device() -> str:
return "cpu"
@lru_cache()
@lru_cache
def get_hardware_description(device_type: str) -> str:
"""Description of the hardware being used."""
desc = platform.platform()
@ -185,10 +185,9 @@ def check_torch_working():
torch.randn(1, device=get_device())
except RuntimeError as e:
if "CUDA" in str(e):
raise RuntimeError(
"CUDA is not working. Make sure you have a GPU and CUDA installed."
) from e
raise e
msg = "CUDA is not working. Make sure you have a GPU and CUDA installed."
raise RuntimeError(msg) from e
raise
def frange(start, stop, step):
@ -209,7 +208,7 @@ def shrink_list(items, max_size):
new_items = {}
for i, item in enumerate(items):
new_items[int(i / removal_ratio)] = item
return [items[0]] + list(new_items.values())
return [items[0], *list(new_items.values())]
def glob_expand_paths(paths):

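shrink_list now builds its return value with star-unpacking instead of list concatenation, the rewrite ruff's RUF005 rule suggests. Both forms produce the same list; the unpacked literal just avoids the intermediate list created by +. A minimal comparison with made-up data:

items = ["a", "b", "c", "d"]
kept = {0: "b", 1: "d"}

old_style = [items[0]] + list(kept.values())   # concatenation (flagged by RUF005)
new_style = [items[0], *kept.values()]         # single list literal

assert old_style == new_style == ["a", "b", "d"]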
@ -1,3 +1,4 @@
import contextlib
import math
import sys
from copy import deepcopy
@ -77,7 +78,7 @@ class DataDistorter:
def __init__(self, data, add_data_values=True):
self.data = deepcopy(data)
self.data_map, self.data_unique_values = create_node_map(self.data)
self.distortion_values = DISTORTED_VALUES + []
self.distortion_values = [*DISTORTED_VALUES]
if add_data_values:
self.distortion_values += list(self.data_unique_values)
@ -141,15 +142,13 @@ def create_node_map(data: Union[dict, list, tuple]) -> Tuple[Dict[int, list], se
if isinstance(curr_data, dict):
for key, value in curr_data.items():
_traverse(value, curr_path + [key])
_traverse(value, [*curr_path, key])
elif isinstance(curr_data, (list, tuple)):
for idx, item in enumerate(curr_data):
_traverse(item, curr_path + [idx])
_traverse(item, [*curr_path, idx])
else:
try:
with contextlib.suppress(TypeError):
node_values.add(curr_data)
except TypeError:
pass
_traverse(data, [])
return node_map, node_values
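The try/except/pass blocks collapse into contextlib.suppress, the rewrite suggested by ruff's SIM105 rule: the context manager swallows only the named exception and reads as a single statement. A small sketch, mirroring the node_values case above with a hypothetical set:

import contextlib

values = set()

with contextlib.suppress(TypeError):
    values.add({"unhashable": "dict"})  # dicts are unhashable, so the TypeError is suppressed

assert values == set()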

@ -203,9 +203,8 @@ class GPUModelCache:
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3)
raise ValueError(
f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
)
msg = f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
raise ValueError(msg)
return self._max_cpu_memory_gb * (1024**3)
@cached_property
@ -224,9 +223,8 @@ class GPUModelCache:
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3)
raise ValueError(
f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
)
msg = f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
raise ValueError(msg)
return self._max_gpu_memory_gb * (1024**3)
def _move_to_gpu(self, key, model):
@ -280,12 +278,12 @@ class GPUModelCache:
import torch
if key not in self:
raise KeyError(f"The key {key} does not exist in the cache")
msg = f"The key {key} does not exist in the cache"
raise KeyError(msg)
if key in self.cpu_cache:
if self.device != torch.device("cpu"):
self.cpu_cache.move_to_end(key)
self._move_to_gpu(key, self.cpu_cache[key])
if key in self.cpu_cache and self.device != torch.device("cpu"):
self.cpu_cache.move_to_end(key)
self._move_to_gpu(key, self.cpu_cache[key])
if key in self.gpu_cache:
self.gpu_cache.move_to_end(key)
@ -337,7 +335,7 @@ class MemoryManagedModelWrapper:
self._mmmw_kwargs = kwargs
self._mmmw_namespace = namespace
self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items())
self._mmmw_cache_key = (namespace, *args, *tuple(kwargs.items()))
def _mmmw_load_model(self):
if self._mmmw_cache_key not in self.__class__._mmmw_cache:

@ -85,7 +85,7 @@ class BatchedBrownianTree:
seed = [seed]
self.batched = False
self.trees = [
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed # noqa
]
@staticmethod

@ -1,10 +1,6 @@
black
coverage
isort
ruff
pycln
pylama
pylint
pytest
pytest-randomly
pytest-sugar

@ -16,8 +16,6 @@ anyio==3.7.1
# via
# fastapi
# starlette
astroid==2.15.8
# via pylint
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
@ -36,7 +34,6 @@ click==8.1.7
# click-help-colors
# click-shell
# imaginAIry (setup.py)
# typer
# uvicorn
click-help-colors==0.9.2
# via imaginAIry (setup.py)
@ -48,10 +45,8 @@ coverage==7.3.1
# via -r requirements-dev.in
cycler==0.12.0
# via matplotlib
diffusers==0.21.3
diffusers==0.21.4
# via imaginAIry (setup.py)
dill==0.3.7
# via pylint
einops==0.6.1
# via imaginAIry (setup.py)
exceptiongroup==1.1.3
@ -71,7 +66,7 @@ filelock==3.12.4
# transformers
filterpy==1.4.5
# via facexlib
fonttools==4.42.1
fonttools==4.43.0
# via matplotlib
frozenlist==1.4.0
# via
@ -104,18 +99,10 @@ importlib-metadata==6.8.0
# via diffusers
iniconfig==2.0.0
# via pytest
isort==5.12.0
# via
# -r requirements-dev.in
# pylint
kiwisolver==1.4.5
# via matplotlib
kornia==0.7.0
# via imaginAIry (setup.py)
lazy-object-proxy==1.9.0
# via astroid
libcst==1.0.1
# via pycln
lightning-utilities==0.9.0
# via
# pytorch-lightning
@ -126,18 +113,12 @@ matplotlib==3.7.3
# via
# -c tests/constraints.txt
# filterpy
mccabe==0.7.0
# via
# pylama
# pylint
multidict==6.0.4
# via
# aiohttp
# yarl
mypy-extensions==1.0.0
# via
# black
# typing-inspect
# via black
numba==0.58.0
# via facexlib
numpy==1.24.4
@ -178,9 +159,7 @@ packaging==23.1
# pytorch-lightning
# transformers
pathspec==0.11.2
# via
# black
# pycln
# via black
pillow==10.0.1
# via
# diffusers
@ -190,9 +169,7 @@ pillow==10.0.1
# matplotlib
# torchvision
platformdirs==3.10.0
# via
# black
# pylint
# via black
pluggy==1.3.0
# via pytest
protobuf==3.20.3
@ -201,24 +178,12 @@ protobuf==3.20.3
# open-clip-torch
psutil==5.9.5
# via imaginAIry (setup.py)
pycln==2.2.2
# via -r requirements-dev.in
pycodestyle==2.11.0
# via pylama
pydantic==2.4.2
# via
# fastapi
# imaginAIry (setup.py)
pydantic-core==2.10.1
# via pydantic
pydocstyle==6.3.0
# via pylama
pyflakes==3.1.0
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.6
# via -r requirements-dev.in
pyparsing==3.1.1
# via matplotlib
pytest==7.4.2
@ -237,9 +202,7 @@ pytorch-lightning==1.9.5
pyyaml==6.0.1
# via
# huggingface-hub
# libcst
# omegaconf
# pycln
# pytorch-lightning
# responses
# timm
@ -280,8 +243,6 @@ six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via anyio
snowballstemmer==2.2.0
# via pydocstyle
starlette==0.27.0
# via fastapi
termcolor==2.3.0
@ -295,12 +256,7 @@ tokenizers==0.13.3
tomli==2.0.1
# via
# black
# pylint
# pytest
tomlkit==0.12.1
# via
# pycln
# pylint
torch==1.13.1
# via
# facexlib
@ -335,28 +291,20 @@ tqdm==4.66.1
# transformers
transformers==4.33.3
# via imaginAIry (setup.py)
typer==0.9.0
# via pycln
types-pyyaml==6.0.12.12
# via responses
typing-extensions==4.8.0
# via
# astroid
# black
# fastapi
# huggingface-hub
# libcst
# lightning-utilities
# pydantic
# pydantic-core
# pytorch-lightning
# torch
# torchvision
# typer
# typing-inspect
# uvicorn
typing-inspect==0.9.0
# via libcst
urllib3==2.0.5
# via
# requests
@ -367,8 +315,6 @@ wcwidth==0.2.7
# via ftfy
wheel==0.41.2
# via -r requirements-dev.in
wrapt==1.15.0
# via astroid
yarl==1.9.2
# via aiohttp
zipp==3.17.0

@ -52,8 +52,8 @@ def main():
time.sleep(1)
controlnet_statedict = torch.load(controlnet_path, map_location="cpu")
print("\n\nComparing reconstructed controlnet with original")
for k in controlnet_statedict.keys():
if k not in reconstituted_controlnet_statedict.keys():
for k in controlnet_statedict:
if k not in reconstituted_controlnet_statedict:
print(f"Key {k} not in reconstituted")
elif (
controlnet_statedict[k].shape

@ -54,7 +54,7 @@ def make_txts():
with open(src_json, encoding="utf-8") as f:
prompts = json.load(f)
categories = []
for c in prompts.keys():
for c in prompts:
if any(c.startswith(p) for p in excluded_prefixes):
continue
categories.append(c)

@ -19,7 +19,7 @@ else:
entry_points = None
@lru_cache()
@lru_cache
def get_git_revision_hash() -> str:
try:
return (

@ -1,3 +1,4 @@
import contextlib
import logging
import os
import sys
@ -33,16 +34,14 @@ elif get_device() == "cpu":
@pytest.fixture(scope="session", autouse=True)
def pre_setup():
def _pre_setup():
api.IMAGINAIRY_SAFETY_MODE = "disabled"
suppress_annoying_logs_and_warnings()
test_output_folder = f"{TESTS_FOLDER}/test_output"
# delete the testoutput folder and recreate it
try:
with contextlib.suppress(FileNotFoundError):
rmtree(test_output_folder)
except FileNotFoundError:
pass
os.makedirs(test_output_folder, exist_ok=True)
orig_urlopen = HTTPConnectionPool.urlopen
@ -73,7 +72,7 @@ def pre_setup():
@pytest.fixture(autouse=True)
def reset_get_device():
def _reset_get_device():
get_device.cache_clear()
@ -94,7 +93,7 @@ def sampler_type(request):
return request.param
@pytest.fixture
@pytest.fixture()
def mocked_responses():
with responses.RequestsMock() as rsps:
yield rsps

@ -12,7 +12,7 @@ blur_params = [
]
@pytest.mark.parametrize("img_path,expected", blur_params)
@pytest.mark.parametrize(("img_path", "expected"), blur_params)
def test_calculate_blurriness_level(img_path, expected):
img = Image.open(img_path)

@ -15,7 +15,7 @@ def control_img_to_pillow_img(img_t):
control_mode_params = list(CONTROL_MODES.items())
@pytest.mark.parametrize("control_name,control_func", control_mode_params)
@pytest.mark.parametrize(("control_name", "control_func"), control_mode_params)
def test_control_images(filename_base_for_outputs, control_func, control_name):
seed_everything(42)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")

@ -26,7 +26,7 @@ strat_combos = [
@pytest.mark.skipif(True, reason="Run manually as needed. Uses too much memory.")
@pytest.mark.parametrize("encode_strat,decode_strat", strat_combos)
@pytest.mark.parametrize(("encode_strat", "decode_strat"), strat_combos)
def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat):
"""
Test that encoding and decoding works.

@ -0,0 +1,14 @@
extend-ignore = ["E501", "G004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100"]
extend-exclude = ["imaginairy/vendored", "downloads", "other"]
extend-select = [
"I", "E", "W", "UP", "ASYNC", "BLE", "A001", "A002",
"C4", "DTZ", "T10", "EM", "ISC", "ICN", "G", "PIE", "PT",
"Q", "SIM", "TID", "TCH", "PLC", "PLE", "TRY", "RUF"
]
[isort]
combine-as-imports = true
[flake8-errmsg]
max-string-length = 50

@ -174,10 +174,7 @@ def test_img_to_img_fruit_2_gold(
filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
)
target_steps = 25
if init_strength >= 1:
needed_steps = 25
else:
needed_steps = int(target_steps / (1 - init_strength))
needed_steps = 25 if init_strength >= 1 else int(target_steps / (1 - init_strength))
prompt = ImaginePrompt(
"a white bowl filled with gold coins",
prompt_strength=12,

@ -126,7 +126,7 @@ boolean_mask_test_cases = [
]
@pytest.mark.parametrize("mask_text,expected", boolean_mask_test_cases)
@pytest.mark.parametrize(("mask_text", "expected"), boolean_mask_test_cases)
def test_clip_mask_parser(mask_text, expected):
parsed = MASK_PROMPT.parseString(mask_text)[0][0]
assert str(parsed) == expected

@ -46,7 +46,7 @@ cases = [
]
@pytest.mark.parametrize("img_ratio, tile_size, overlap_pct", cases)
@pytest.mark.parametrize(("img_ratio", "tile_size", "overlap_pct"), cases)
def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
img = pillow_img_to_torch_image(
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")

@ -23,7 +23,7 @@ def test_outpainting_outpaint(filename_base_for_outputs):
steps=20,
seed=542906833,
)
result = list(imagine([prompt]))[0]
result = next(iter(imagine([prompt])))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=17000)
@ -37,6 +37,6 @@ outpaint_test_params = [
]
@pytest.mark.parametrize("arg_str, expected_kwargs", outpaint_test_params)
@pytest.mark.parametrize(("arg_str", "expected_kwargs"), outpaint_test_params)
def test_outpaint_parse_kwargs(arg_str, expected_kwargs):
assert outpaint_arg_str_parse(arg_str) == expected_kwargs

@ -5,7 +5,7 @@ from imaginairy.utils import frange
@pytest.mark.parametrize(
"schedule_str,expected",
("schedule_str", "expected"),
[
("prompt_strength[2:40:1]", ("prompt_strength", list(range(2, 40)))),
("prompt_strength[2:40:0.5]", ("prompt_strength", list(frange(2, 40, 0.5)))),

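The test hunks rewrite pytest.mark.parametrize argument names from a single comma-separated string to a tuple of strings, the form preferred by ruff's PT006 rule; pytest accepts both spellings. A hedged sketch with a throwaway test:

import pytest


@pytest.mark.parametrize(
    ("mask_text", "expected"),  # tuple form instead of "mask_text,expected"
    [("foo AND bar", "AND(foo, bar)")],
)
def test_parametrize_tuple_form(mask_text, expected):
    assert isinstance(mask_text, str)
    assert isinstance(expected, str)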
@ -26,7 +26,7 @@ def _red_url(mocked_responses):
status=200,
content_type="image/png",
)
yield url
return url
@pytest.fixture(name="red_path")

@ -107,7 +107,7 @@ def test_cache_ordering():
)
cache.set("key-0", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa
assert list(cache.cpu_cache.keys()) == []
assert list(cache.gpu_cache.keys()) == ["key-0"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0,
@ -115,7 +115,7 @@ def test_cache_ordering():
)
cache.set("key-1", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa
assert list(cache.cpu_cache.keys()) == []
assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0,

@ -68,7 +68,7 @@ def test_instantiate_from_config():
"params": {"year": 2002, "month": 10, "day": 1},
}
o = instantiate_from_config(config)
assert o == datetime(2002, 10, 1)
assert o == datetime(2002, 10, 1) # noqa: DTZ001
config = "__is_first_stage__"
assert instantiate_from_config(config) is None

@ -4,25 +4,3 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
[pylama]
format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py,.eggs/*
linters = pylint,pycodestyle,pyflakes,mypy
ignore =
Z999,C0103,C0201,C0301,C0302,C0114,C0115,C0116,C0415,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,E1121,E1131,E1133,E1135,E1136,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W0632,W1203
[pylama:tests/*]
ignore = C0104,C0114,C0116,D103,W0143,W0613
[pylama:*/__init__.py]
ignore = D104
[pylama:pylint]
generated_members=torch.*
extension-pkg-whitelist=pydantic
