diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 06b2792..2982960 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,35 +9,42 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.9 - cache: pip - cache-dependency-path: requirements-dev.txt - - name: Install dependencies - run: | - python -m pip install --disable-pip-version-check wheel pip-tools - pip-sync requirements-dev.txt - python -m pip install --disable-pip-version-check --no-deps . - - name: Lint - run: | - echo "::add-matcher::.github/pylama_matcher.json" - pylama --options tox.ini + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4.5.0 + with: + python-version: 3.9 + - name: Cache dependencies + uses: actions/cache@v3.2.4 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-lint + - name: Install Ruff + if: steps.cache.outputs.cache-hit != 'true' + run: grep -E 'ruff==' requirements-dev.txt | xargs pip install + - name: Lint + run: | + echo "::add-matcher::.github/pylama_matcher.json" + ruff --config tests/ruff.toml . autoformat: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - name: Install dependencies - run: | - python -m pip install --disable-pip-version-check black==23.1.0 isort==5.12.0 - - name: Autoformatter - run: | - black --diff . - isort --atomic --profile black --check-only . + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4.5.0 + with: + python-version: 3.9 + - name: Cache dependencies + uses: actions/cache@v3.2.4 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-autoformat + - name: Install Black + if: steps.cache.outputs.cache-hit != 'true' + run: grep -E 'black==' requirements-dev.txt | xargs pip install + - name: Lint + run: | + black --diff --fast . test: runs-on: ubuntu-latest strategy: diff --git a/Makefile b/Makefile index 7cfb04e..2aeae06 100644 --- a/Makefile +++ b/Makefile @@ -28,18 +28,16 @@ init: require_pyenv ## Setup a dev environment for local development. af: autoformat ## Alias for `autoformat` autoformat: ## Run the autoformatter. - @pycln . --all --quiet --extend-exclude __init__\.py - @# ERA,T201 - @-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,W --unfixable T,ERA --fix-only . + @-ruff check --config tests/ruff.toml . --fix-only @black . - @isort --atomic --profile black --skip downloads/** . + test: ## Run the tests. @pytest @echo -e "The tests pass! ✨ 🍰 ✨" lint: ## Run the code linter. - @pylama + @ruff check --config tests/ruff.toml . @echo -e "No linting errors - well done! ✨ 🍰 ✨" deploy: ## Deploy the package to pypi.org diff --git a/docs/examples/immortal_pearl_earring.py b/docs/examples/immortal_pearl_earring.py index a2e611e..aaaa5d9 100644 --- a/docs/examples/immortal_pearl_earring.py +++ b/docs/examples/immortal_pearl_earring.py @@ -63,7 +63,7 @@ def generate_image_morph_video(): if os.path.exists(filename): continue - result = list(imagine([prompt]))[0] + result = next(iter(imagine([prompt]))) generated_image = result.images["generated"] draw = ImageDraw.Draw(generated_image) diff --git a/imaginairy/animations.py b/imaginairy/animations.py index 7bd2998..9afe2a6 100644 --- a/imaginairy/animations.py +++ b/imaginairy/animations.py @@ -33,7 +33,7 @@ def make_bounce_animation( middle_imgs = shrink_list(middle_imgs, max_frames) - frames = [first_img] + middle_imgs + [last_img] + list(reversed(middle_imgs)) + frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))] # convert from latents converted_frames = [] diff --git a/imaginairy/api.py b/imaginairy/api.py index c9162fe..745bc59 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -92,13 +92,13 @@ def imagine_image_files( os.makedirs(subpath, exist_ok=True) filepath = os.path.join(subpath, f"{basefilename}.gif") - frames = result.progress_latents + [result.images["generated"]] + frames = [*result.progress_latents, result.images["generated"]] if prompt.init_image: resized_init_image = pillow_fit_image_within( prompt.init_image, prompt.width, prompt.height ) - frames = [resized_init_image] + frames + frames = [resized_init_image, *frames] frames.reverse() make_bounce_animation( imgs=frames, @@ -170,7 +170,7 @@ def imagine( logger.info( f"🖼 Generating {i + 1}/{num_prompts}: {prompt.prompt_description()}" ) - for attempt in range(0, unsafe_retry_count + 1): + for attempt in range(unsafe_retry_count + 1): if attempt > 0 and isinstance(prompt.seed, int): prompt.seed += 100_000_000 + attempt result = _generate_single_image( @@ -238,7 +238,7 @@ def _generate_single_image( latent_channels = 4 downsampling_factor = 8 batch_size = 1 - global _most_recent_result # noqa + global _most_recent_result # handle prompt pulling in previous values # if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"): # _, img_type = prompt.init_image.strip("*").split(".") @@ -457,16 +457,17 @@ def _generate_single_image( if control_image_t.shape[1] != 3: raise RuntimeError("Control image must have 3 channels") - if control_input.mode != "inpaint": - if control_image_t.min() < 0 or control_image_t.max() > 1: - raise RuntimeError( - f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" - ) + if ( + control_input.mode != "inpaint" + and control_image_t.min() < 0 + or control_image_t.max() > 1 + ): + msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" + raise RuntimeError(msg) if control_image_t.max() == control_image_t.min(): - raise RuntimeError( - f"No control signal found in control image {control_input.mode}." - ) + msg = f"No control signal found in control image {control_input.mode}." + raise RuntimeError(msg) c_cat.append(control_image_t) control_strengths.append(control_input.strength) @@ -517,7 +518,7 @@ def _generate_single_image( if ( prompt.allow_compose_phase and not is_controlnet_model - and not model.cond_stage_key == "edit" + and model.cond_stage_key != "edit" ): if prompt.init_image: comp_image = _generate_composition_image( diff --git a/imaginairy/cli/clickshell_mod.py b/imaginairy/cli/clickshell_mod.py index 7c424d9..4ef3c0e 100644 --- a/imaginairy/cli/clickshell_mod.py +++ b/imaginairy/cli/clickshell_mod.py @@ -4,6 +4,7 @@ import logging import shlex import traceback from functools import update_wrapper +from typing import ClassVar import click from click_help_colors import HelpColorsCommand, HelpColorsMixin @@ -43,27 +44,23 @@ def mod_get_invoke(command): # and that's not ideal when running in a shell. pass except Exception as e: # noqa - traceback.print_exception(e) # noqa + traceback.print_exception(e) # logger.warning(traceback.format_exc()) # Always return False so the shell doesn't exit return False invoke_ = update_wrapper(invoke_, command.callback) - invoke_.__name__ = "do_%s" % command.name # noqa + invoke_.__name__ = "do_%s" % command.name return invoke_ class ModClickShell(ClickShell): def add_command(self, cmd, name): # Use the MethodType to add these as bound methods to our current instance - setattr( - self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self) # noqa - ) - setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self)) # noqa - setattr( - self, "complete_%s" % name, get_method_type(get_complete(cmd), self) # noqa - ) + setattr(self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self)) + setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self)) + setattr(self, "complete_%s" % name, get_method_type(get_complete(cmd), self)) class ModShell(Shell): @@ -85,7 +82,7 @@ class ColorShell(HelpColorsMixin, ModShell): class ImagineColorsCommand(HelpColorsCommand): - _option_order = [] + _option_order: ClassVar = [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/imaginairy/cli/edit.py b/imaginairy/cli/edit.py index 9b1b771..5c911c9 100644 --- a/imaginairy/cli/edit.py +++ b/imaginairy/cli/edit.py @@ -43,7 +43,7 @@ remove_option(edit_options, "allow_compose_phase") ) @add_options(edit_options) @click.pass_context -def edit_cmd( # noqa +def edit_cmd( ctx, image_paths, image_strength, @@ -77,7 +77,7 @@ def edit_cmd( # noqa model_weights_path, model_config_path, prompt_library_path, - version, # noqa + version, make_gif, make_compare_gif, arg_schedules, @@ -130,7 +130,7 @@ def edit_cmd( # noqa model_weights_path, model_config_path, prompt_library_path, - version, # noqa + version, make_gif, make_compare_gif, arg_schedules, diff --git a/imaginairy/cli/imagine.py b/imaginairy/cli/imagine.py index 15c74ca..5c130ec 100644 --- a/imaginairy/cli/imagine.py +++ b/imaginairy/cli/imagine.py @@ -90,7 +90,7 @@ def imagine_cmd( model_weights_path, model_config_path, prompt_library_path, - version, # noqa + version, make_gif, make_compare_gif, arg_schedules, @@ -110,7 +110,7 @@ def imagine_cmd( # hacky method of getting order of control images (mixing raw and normal images) control_images = [ (o, path) - for o, path in ImagineColorsCommand._option_order # noqa + for o, path in ImagineColorsCommand._option_order if o.name in ("control_image", "control_image_raw") ] control_inputs = [] @@ -176,7 +176,7 @@ def imagine_cmd( model_weights_path, model_config_path, prompt_library_path, - version, # noqa + version, make_gif, make_compare_gif, arg_schedules, @@ -187,4 +187,4 @@ def imagine_cmd( if __name__ == "__main__": - imagine_cmd() # noqa + imagine_cmd() diff --git a/imaginairy/cli/main.py b/imaginairy/cli/main.py index 110b48a..840e5ba 100644 --- a/imaginairy/cli/main.py +++ b/imaginairy/cli/main.py @@ -92,4 +92,4 @@ def model_list_cmd(): if __name__ == "__main__": - aimg() # noqa + aimg() diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py index 6cf7fd7..a6982a7 100644 --- a/imaginairy/cli/shared.py +++ b/imaginairy/cli/shared.py @@ -43,7 +43,7 @@ def _imagine_cmd( model_weights_path, model_config_path, prompt_library_path, - version=False, # noqa + version=False, make_gif=False, make_compare_gif=False, arg_schedules=None, @@ -78,10 +78,7 @@ def _imagine_cmd( configure_logging(log_level) - if isinstance(init_image, str): - init_images = [init_image] - else: - init_images = init_image + init_images = [init_image] if isinstance(init_image, str) else init_image from imaginairy.utils import glob_expand_paths @@ -89,9 +86,8 @@ def _imagine_cmd( init_images = glob_expand_paths(init_images) if len(init_images) < num_prexpaned_init_images: - raise ValueError( - f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?" - ) + msg = f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?" + raise ValueError(msg) total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats logger.info( @@ -227,7 +223,8 @@ def replace_option(options, option_name, new_option): if option.name == option_name: options[i] = new_option return - raise ValueError(f"Option {option_name} not found") + msg = f"Option {option_name} not found" + raise ValueError(msg) def remove_option(options, option_name): @@ -242,7 +239,8 @@ def remove_option(options, option_name): if option.name == option_name: del options[i] return - raise ValueError(f"Option {option_name} not found") + msg = f"Option {option_name} not found" + raise ValueError(msg) common_options = [ diff --git a/imaginairy/colorize.py b/imaginairy/colorize.py index 7e86f0b..c3ec536 100644 --- a/imaginairy/colorize.py +++ b/imaginairy/colorize.py @@ -36,7 +36,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None): steps=30, prompt_strength=12, ) - result = list(imagine(prompt))[0] + result = next(iter(imagine(prompt))) colorized_img = replace_color(img, result.images["generated"]) # allows the algorithm some leeway for the overall brightness of the image diff --git a/imaginairy/enhancers/bool_masker.py b/imaginairy/enhancers/bool_masker.py index dd0cc17..4689349 100644 --- a/imaginairy/enhancers/bool_masker.py +++ b/imaginairy/enhancers/bool_masker.py @@ -18,6 +18,7 @@ Examples: """ import operator from abc import ABC +from typing import ClassVar import pyparsing as pp import torch @@ -57,7 +58,7 @@ class SimpleMask(Mask): class ModifiedMask(Mask): - ops = { + ops: ClassVar = { "+": operator.add, "-": operator.sub, "*": operator.mul, @@ -80,7 +81,7 @@ class ModifiedMask(Mask): return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1]) def __repr__(self): - return f"{repr(self.mask)}{self.modifier}" + return f"{self.mask!r}{self.modifier}" def gather_text_descriptions(self): return self.mask.gather_text_descriptions() @@ -141,7 +142,8 @@ class NestedMask(Mask): elif self.op == "NOT": mask = 1 - mask else: - raise ValueError(f"Invalid operand {self.op}") + msg = f"Invalid operand {self.op}" + raise ValueError(msg) return torch.clamp(mask, 0, 1) diff --git a/imaginairy/enhancers/clip_masking.py b/imaginairy/enhancers/clip_masking.py index 60a5354..74c78a6 100644 --- a/imaginairy/enhancers/clip_masking.py +++ b/imaginairy/enhancers/clip_masking.py @@ -14,9 +14,9 @@ from imaginairy.vendored.clipseg import CLIPDensePredT weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth" -@lru_cache() +@lru_cache def clip_mask_model(): - from imaginairy.paths import PKG_ROOT # noqa + from imaginairy.paths import PKG_ROOT model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64, complex_trans_conv=True) model.eval() @@ -36,7 +36,7 @@ def get_img_mask( mask_description_statement: str, threshold: Optional[float] = None, ): - from imaginairy.enhancers.bool_masker import MASK_PROMPT # noqa + from imaginairy.enhancers.bool_masker import MASK_PROMPT parsed = MASK_PROMPT.parseString(mask_description_statement) parsed_mask = parsed[0][0] diff --git a/imaginairy/enhancers/describe_image_blip.py b/imaginairy/enhancers/describe_image_blip.py index 224f6cb..12fa6e9 100644 --- a/imaginairy/enhancers/describe_image_blip.py +++ b/imaginairy/enhancers/describe_image_blip.py @@ -17,9 +17,9 @@ if "mps" in device: BLIP_EVAL_SIZE = 384 -@lru_cache() +@lru_cache def blip_model(): - from imaginairy.paths import PKG_ROOT # noqa + from imaginairy.paths import PKG_ROOT config_path = os.path.join( PKG_ROOT, "vendored", "blip", "configs", "med_config.json" @@ -28,7 +28,7 @@ def blip_model(): model = BLIP_Decoder(image_size=BLIP_EVAL_SIZE, vit="base", med_config=config_path) cached_url_path = get_cached_url_path(url) - model, msg = load_checkpoint(model, cached_url_path) # noqa + model, msg = load_checkpoint(model, cached_url_path) model.eval() model = model.to(device) return model diff --git a/imaginairy/enhancers/describe_image_clip.py b/imaginairy/enhancers/describe_image_clip.py index f2a4cd3..42cec54 100644 --- a/imaginairy/enhancers/describe_image_clip.py +++ b/imaginairy/enhancers/describe_image_clip.py @@ -10,7 +10,7 @@ from imaginairy.vendored import clip device = "cuda" if torch.cuda.is_available() else "cpu" -@lru_cache() +@lru_cache def get_model(): model_name = "ViT-L/14" model, preprocess = clip.load(model_name, device=device) diff --git a/imaginairy/enhancers/face_restoration_codeformer.py b/imaginairy/enhancers/face_restoration_codeformer.py index 70378bc..377139c 100644 --- a/imaginairy/enhancers/face_restoration_codeformer.py +++ b/imaginairy/enhancers/face_restoration_codeformer.py @@ -17,7 +17,7 @@ face_restore_device = torch.device("cuda" if torch.cuda.is_available() else "cpu half_mode = face_restore_device == "cuda" -@lru_cache() +@lru_cache def codeformer_model(): model = CodeFormer( dim_embd=512, @@ -36,7 +36,7 @@ def codeformer_model(): return model -@lru_cache() +@lru_cache def face_restore_helper(): """ Provide a singleton of FaceRestoreHelper. @@ -85,11 +85,11 @@ def enhance_faces(img, fidelity=0): try: with torch.no_grad(): - output = net(cropped_face_t, w=fidelity, adain=True)[0] # noqa + output = net(cropped_face_t, w=fidelity, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: # noqa + except Exception as error: logger.exception(f"\tFailed inference for CodeFormer: {error}") restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) diff --git a/imaginairy/enhancers/prompt_expansion.py b/imaginairy/enhancers/prompt_expansion.py index d465f65..044c636 100644 --- a/imaginairy/enhancers/prompt_expansion.py +++ b/imaginairy/enhancers/prompt_expansion.py @@ -15,10 +15,10 @@ formatter = Formatter() PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+") -@lru_cache() +@lru_cache def prompt_library_filepaths(prompt_library_paths=None): """Return all available category/filepath pairs.""" - prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths + prompt_library_paths = prompt_library_paths if prompt_library_paths else [] combined_prompt_library_filepaths = {} for prompt_path in DEFAULT_PROMPT_LIBRARY_PATHS + list(prompt_library_paths): library_prompts = prompt_library_filepath(prompt_path) @@ -27,7 +27,7 @@ def prompt_library_filepaths(prompt_library_paths=None): return combined_prompt_library_filepaths -@lru_cache() +@lru_cache def category_list(prompt_library_paths=None): """Return the names of available phrase-lists.""" categories = list(prompt_library_filepaths(prompt_library_paths).keys()) @@ -35,7 +35,7 @@ def category_list(prompt_library_paths=None): return categories -@lru_cache() +@lru_cache def prompt_library_filepath(library_path): lookup = {} @@ -55,9 +55,8 @@ def get_phrases(category_name, prompt_library_paths=None): try: filepath = lookup[category_name] except KeyError as e: - raise LookupError( - f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file." - ) from e + msg = f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file." + raise LookupError(msg) from e _open = open if filepath.endswith(".gz"): _open = gzip.open @@ -83,13 +82,12 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None): """ prompt_parts = list(formatter.parse(prompt_text)) field_names = [] - for literal_text, field_name, format_spec, conversion in prompt_parts: # noqa + for literal_text, field_name, format_spec, conversion in prompt_parts: if field_name: field_name = field_name.lower() if not PROMPT_EXPANSION_PATTERN.match(field_name): - raise ValueError( - "Invalid prompt expansion. Only a-z0-9_|- characters permitted. " - ) + msg = "Invalid prompt expansion. Only a-z0-9_|- characters permitted. " + raise ValueError(msg) field_names.append(field_name) phrases = [] @@ -120,9 +118,7 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None): yield output_prompt -def get_random_non_repeating_combination( # noqa - n=1, *sequences, allow_oversampling=True -): +def get_random_non_repeating_combination(n=1, *sequences, allow_oversampling=True): """ Efficiently return a non-repeating random sample of the product sequences. diff --git a/imaginairy/enhancers/upscale_riverwing.py b/imaginairy/enhancers/upscale_riverwing.py index 9f7c959..81c43a8 100644 --- a/imaginairy/enhancers/upscale_riverwing.py +++ b/imaginairy/enhancers/upscale_riverwing.py @@ -187,7 +187,7 @@ class CLIPEmbedder(nn.Module): ) -@lru_cache() +@lru_cache def clip_up_models(): with platform_appropriate_autocast(): tok_up = CLIPTokenizerTransform() @@ -290,7 +290,8 @@ def upscale_latent( eta=eta, **sampler_opts, ) - raise ValueError(f"Unknown sampler {sampler}") + msg = f"Unknown sampler {sampler}" + raise ValueError(msg) for _ in range((num_samples - 1) // batch_size + 1): if noise_aug_type == "gaussian": @@ -300,7 +301,7 @@ def upscale_latent( elif noise_aug_type == "fake": latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5 extra_args = { - "low_res": latent_noised, # noqa + "low_res": latent_noised, "low_res_sigma": low_res_sigma, "c": c, } diff --git a/imaginairy/http/stablestudio/models.py b/imaginairy/http/stablestudio/models.py index 1d30928..9d62606 100644 --- a/imaginairy/http/stablestudio/models.py +++ b/imaginairy/http/stablestudio/models.py @@ -63,7 +63,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid): initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage") @validator("seed") - def validate_seed(cls, v): # noqa + def validate_seed(cls, v): if v == 0: return None return v @@ -74,10 +74,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid): from PIL import Image - if self.prompts: - positive_prompt = self.prompts[0].text - else: - positive_prompt = None + positive_prompt = self.prompts[0].text if self.prompts else None if self.prompts and len(self.prompts) > 1: negative_prompt = self.prompts[1].text if len(self.prompts) > 1 else None else: diff --git a/imaginairy/img_processors/control_modes.py b/imaginairy/img_processors/control_modes.py index 633f5af..8ea3f5f 100644 --- a/imaginairy/img_processors/control_modes.py +++ b/imaginairy/img_processors/control_modes.py @@ -79,7 +79,7 @@ def _create_depth_map_raw(img): align_corners=False, ) - depth_pt = model(img)[0] # noqa + depth_pt = model(img)[0] return depth_pt @@ -209,7 +209,10 @@ def inpaint_prep(mask_image_t, target_image_t): def to_grayscale(img): # The dimensions of input should be (batch_size, channels, height, width) - assert img.dim() == 4 and img.size(1) == 3 + if img.dim() != 4: + raise ValueError("Input should be a 4d tensor") + if img.size(1) != 3: + raise ValueError("Input should have 3 channels") # Apply the formula to convert to grayscale. gray = ( diff --git a/imaginairy/img_processors/openpose.py b/imaginairy/img_processors/openpose.py index 8b56db9..9ff4ed7 100644 --- a/imaginairy/img_processors/openpose.py +++ b/imaginairy/img_processors/openpose.py @@ -3,7 +3,7 @@ from collections import OrderedDict from functools import lru_cache import cv2 -import matplotlib +import matplotlib as mpl import numpy as np import torch from scipy.ndimage.filters import gaussian_filter @@ -40,7 +40,7 @@ def pad_right_down_corner(img, stride, padValue): def transfer(model, model_weights): # transfer caffe model to pytorch which will match the layer name transfered_model_weights = {} - for weights_name in model.state_dict().keys(): + for weights_name in model.state_dict(): transfered_model_weights[weights_name] = model_weights[ ".".join(weights_name.split(".")[1:]) ] @@ -93,14 +93,14 @@ def draw_bodypose(canvas, candidate, subset): [255, 0, 85], ] for i in range(18): - for n in range(len(subset)): # noqa + for n in range(len(subset)): index = int(subset[n][i]) if index == -1: continue x, y = candidate[index][0:2] cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) for i in range(17): - for n in range(len(subset)): # noqa + for n in range(len(subset)): index = subset[n][np.array(limbSeq[i]) - 1] if -1 in index: continue @@ -155,8 +155,7 @@ def draw_handpose(canvas, all_hand_peaks, show_number=False): canvas, (x1, y1), (x2, y2), - matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) - * 255, + mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2, ) @@ -368,7 +367,7 @@ class bodypose_model(nn.Module): ] ) - for k in blocks.keys(): + for k in blocks: blocks[k] = make_layers(blocks[k], no_relu_layers) self.model1_1 = blocks["block1_1"] @@ -473,7 +472,7 @@ class handpose_model(nn.Module): ] ) - for k in blocks.keys(): + for k in blocks: blocks[k] = make_layers(blocks[k], no_relu_layers) self.model1_0 = blocks["block1_0"] @@ -625,7 +624,7 @@ def create_body_pose(original_img_t): peaks = list( zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]) ) # note reverse - peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peaks_with_score = [(*x, map_ori[x[1], x[0]]) for x in peaks] peak_id = range(peak_counter, peak_counter + len(peaks)) peaks_with_score_and_id = [ peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id)) @@ -751,7 +750,7 @@ def create_body_pose(original_img_t): connection_candidate, key=lambda x: x[2], reverse=True ) connection = np.zeros((0, 5)) - for c in range(len(connection_candidate)): # noqa + for c in range(len(connection_candidate)): i, j, s = connection_candidate[c][0:3] if i not in connection[:, 3] and j not in connection[:, 4]: connection = np.vstack( diff --git a/imaginairy/img_utils.py b/imaginairy/img_utils.py index 0b68f34..b9adeef 100644 --- a/imaginairy/img_utils.py +++ b/imaginairy/img_utils.py @@ -101,9 +101,10 @@ def torch_img_to_pillow_img(img_t: torch.Tensor): elif img_t.shape[1] == 3: colorspace = "RGB" else: - raise ValueError( + msg = ( f"Unsupported colorspace. {img_t.shape[1]} channels in {img_t.shape} shape" ) + raise ValueError(msg) img_t = rearrange(img_t, "b c h w -> b h w c") img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0) img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0] @@ -113,7 +114,7 @@ def torch_img_to_pillow_img(img_t: torch.Tensor): def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image: - from imaginairy.model_manager import get_current_diffusion_model # noqa + from imaginairy.model_manager import get_current_diffusion_model if len(latent.shape) == 3: latent = latent.unsqueeze(0) diff --git a/imaginairy/log_utils.py b/imaginairy/log_utils.py index ad01a4b..059882c 100644 --- a/imaginairy/log_utils.py +++ b/imaginairy/log_utils.py @@ -94,13 +94,13 @@ class ImageLoggingContext: self._prev_log_context = None def __enter__(self): - global _CURRENT_LOGGING_CONTEXT # noqa + global _CURRENT_LOGGING_CONTEXT self._prev_log_context = _CURRENT_LOGGING_CONTEXT _CURRENT_LOGGING_CONTEXT = self return self def __exit__(self, exc_type, exc_val, exc_tb): - global _CURRENT_LOGGING_CONTEXT # noqa + global _CURRENT_LOGGING_CONTEXT _CURRENT_LOGGING_CONTEXT = self._prev_log_context def timing(self, description): @@ -120,21 +120,20 @@ class ImageLoggingContext: ) def log_latents(self, latents, description): - from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa + from imaginairy.img_utils import model_latents_to_pillow_imgs if "predicted_latent" in description: if self.progress_latent_callback is not None: self.progress_latent_callback(latents) if ( self.step_count - self.last_progress_img_step - ) > self.progress_img_interval_steps: - if ( - time.perf_counter() - self.last_progress_img_ts - > self.progress_img_interval_min_s - ): - self.log_progress_latent(latents) - self.last_progress_img_step = self.step_count - self.last_progress_img_ts = time.perf_counter() + ) > self.progress_img_interval_steps and ( + time.perf_counter() - self.last_progress_img_ts + > self.progress_img_interval_min_s + ): + self.log_progress_latent(latents) + self.last_progress_img_step = self.step_count + self.last_progress_img_ts = time.perf_counter() if not self.debug_img_callback: return @@ -168,7 +167,7 @@ class ImageLoggingContext: ) def log_progress_latent(self, latent): - from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa + from imaginairy.img_utils import model_latents_to_pillow_imgs if not self.progress_img_callback: return @@ -280,7 +279,7 @@ def disable_pytorch_lighting_custom_logging(): from pytorch_lightning import _logger as pytorch_logger try: - from pytorch_lightning.utilities.seed import log # noqa + from pytorch_lightning.utilities.seed import log log.setLevel(logging.NOTSET) log.handlers = [] diff --git a/imaginairy/lr_scheduler.py b/imaginairy/lr_scheduler.py index f6d4ae4..f36ff9e 100644 --- a/imaginairy/lr_scheduler.py +++ b/imaginairy/lr_scheduler.py @@ -24,9 +24,8 @@ class LambdaWarmUpCosineScheduler: self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = ( self.lr_max - self.lr_start @@ -66,7 +65,7 @@ class LambdaWarmUpCosineScheduler2: self.f_min = f_min self.f_max = f_max self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.cum_cycles = np.cumsum([0, *list(self.cycle_lengths)]) self.last_f = 0.0 self.verbosity_interval = verbosity_interval @@ -81,12 +80,11 @@ class LambdaWarmUpCosineScheduler2: def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ cycle @@ -112,12 +110,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ diff --git a/imaginairy/model_manager.py b/imaginairy/model_manager.py index d9f1898..ea4b31b 100644 --- a/imaginairy/model_manager.py +++ b/imaginairy/model_manager.py @@ -7,9 +7,11 @@ from functools import wraps import requests import torch -from huggingface_hub import HfFolder -from huggingface_hub import hf_hub_download as _hf_hub_download -from huggingface_hub import try_to_load_from_cache +from huggingface_hub import ( + HfFolder, + hf_hub_download as _hf_hub_download, + try_to_load_from_cache, +) from omegaconf import OmegaConf from safetensors.torch import load_file @@ -65,16 +67,19 @@ def load_state_dict(weights_location, half_mode=False, device=None): f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.' ) sys.exit(1) - raise e + raise except RuntimeError as e: - if "PytorchStreamReader failed reading zip archive" in str(e): - if weights_location.startswith("http"): - logger.warning("Corrupt checkpoint. deleting and re-downloading...") - os.remove(ckpt_path) - ckpt_path = get_cached_url_path(weights_location, category="weights") - state_dict = load_tensors(ckpt_path, map_location="cpu") + err_str = str(e) + if ( + "PytorchStreamReader failed reading zip archive" in err_str + and weights_location.startswith("http") + ): + logger.warning("Corrupt checkpoint. deleting and re-downloading...") + os.remove(ckpt_path) + ckpt_path = get_cached_url_path(weights_location, category="weights") + state_dict = load_tensors(ckpt_path, map_location="cpu") if state_dict is None: - raise e + raise state_dict = state_dict.get("state_dict", state_dict) @@ -166,7 +171,7 @@ def get_diffusion_model( except HuggingFaceAuthorizationError as e: if for_inpainting: logger.warning( - f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}" + f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}" ) return _get_diffusion_model( iconfig.DEFAULT_MODEL, @@ -176,7 +181,7 @@ def get_diffusion_model( for_training=for_training, control_weights_locations=control_weights_locations, ) - raise e + raise def _get_diffusion_model( @@ -192,7 +197,7 @@ def _get_diffusion_model( Weights location may also be shortcut name, e.g. "SD-1.5" """ - global MOST_RECENTLY_LOADED_MODEL # noqa + global MOST_RECENTLY_LOADED_MODEL ( model_config, @@ -293,9 +298,8 @@ def resolve_model_paths( if for_training: weights_path = model_metadata_w.weights_url_full if weights_path is None: - raise ValueError( - "No full training weights configured for this model. Edit the code or subimt a github issue." - ) + msg = "No full training weights configured for this model. Edit the code or subimt a github issue." + raise ValueError(msg) else: weights_path = model_metadata_w.weights_url @@ -306,9 +310,8 @@ def resolve_model_paths( config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path if control_net_metadatas: if "stable-diffusion-v1" not in config_path: - raise ValueError( - "Control net is only supported for stable diffusion v1. Please use a different model." - ) + msg = "Control net is only supported for stable diffusion v1. Please use a different model." + raise ValueError(msg) control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas] config_path = control_net_metadatas[0].config_path model_metadata = model_metadata_w or model_metadata_c @@ -374,7 +377,7 @@ def get_cached_url_path(url, category=None): os.rename(old_dest_path, dest_path) return dest_path - r = requests.get(url) # noqa + r = requests.get(url) with open(dest_path, "wb") as f: f.write(r.content) @@ -390,12 +393,8 @@ def check_huggingface_url_authorized(url): headers["authorization"] = f"Bearer {token}" response = requests.head(url, allow_redirects=True, headers=headers, timeout=5) if response.status_code == 401: - raise HuggingFaceAuthorizationError( - "Unauthorized access to HuggingFace model. This model requires a huggingface token. " - "Please login to HuggingFace " - "or set HUGGING_FACE_HUB_TOKEN to your User Access Token. " - "See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information" - ) + msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information" + raise HuggingFaceAuthorizationError(msg) return None @@ -413,7 +412,7 @@ def hf_hub_download(*args, **kwargs): if "unexpected keyword argument 'token'" in str(e): kwargs["use_auth_token"] = kwargs.pop("token") return _hf_hub_download(*args, **kwargs) - raise e + raise def huggingface_cached_path(url): diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py index b9c4c58..241d685 100644 --- a/imaginairy/modules/attention.py +++ b/imaginairy/modules/attention.py @@ -14,8 +14,8 @@ XFORMERS_IS_AVAILABLE = False try: if get_device() == "cuda": - import xformers # noqa - import xformers.ops # noqa + import xformers + import xformers.ops XFORMERS_IS_AVAILABLE = True except ImportError: @@ -79,7 +79,7 @@ class LinearAttention(nn.Module): self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): - b, c, h, w = x.shape # noqa + b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange( qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 @@ -120,7 +120,7 @@ class SpatialSelfAttention(nn.Module): v = self.v(h_) # compute attention - b, c, h, w = q.shape # noqa + b, c, h, w = q.shape q = rearrange(q, "b c h w -> b (h w) c") k = rearrange(k, "b c h w -> b c (h w)") w_ = torch.einsum("bij,bjk->bik", q, k) @@ -183,7 +183,7 @@ class CrossAttention(nn.Module): # if mask is None and _global_mask_hack is not None: # mask = _global_mask_hack.to(torch.bool) - if get_device() == "cuda" or "mps" in get_device(): + if get_device() == "cuda" or "mps" in get_device(): # noqa if not XFORMERS_IS_AVAILABLE and ALLOW_SPLITMEM: return self.forward_splitmem(x, context=context, mask=mask) @@ -222,7 +222,7 @@ class CrossAttention(nn.Module): out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) - def forward_splitmem(self, x, context=None, mask=None): # noqa + def forward_splitmem(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -262,10 +262,8 @@ class CrossAttention(nn.Module): max_res = ( math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 ) - raise RuntimeError( - f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). " - f"Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free" - ) + msg = f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free" + raise RuntimeError(msg) slice_size = ( q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] ) @@ -474,7 +472,7 @@ class SpatialTransformer(nn.Module): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] - b, c, h, w = x.shape # noqa + b, c, h, w = x.shape x_in = x x = self.norm(x) if self.use_linear: diff --git a/imaginairy/modules/autoencoder.py b/imaginairy/modules/autoencoder.py index 00c0556..b932a02 100644 --- a/imaginairy/modules/autoencoder.py +++ b/imaginairy/modules/autoencoder.py @@ -290,10 +290,7 @@ class AutoencoderKL(pl.LightningModule): def forward(self, input, sample_posterior=True): # noqa posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() + z = posterior.sample() if sample_posterior else posterior.mode() dec = self.decode(z) return dec, posterior @@ -484,7 +481,7 @@ class AutoencoderKL(pl.LightningModule): :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) """ - bs, nc, h, w = x.shape # noqa + bs, nc, h, w = x.shape # number of crops in image Ly = (h - kernel_size[0]) // stride[0] + 1 diff --git a/imaginairy/modules/cldm.py b/imaginairy/modules/cldm.py index 3acac54..8a47081 100644 --- a/imaginairy/modules/cldm.py +++ b/imaginairy/modules/cldm.py @@ -19,12 +19,12 @@ from imaginairy.modules.diffusion.util import ( class ControlledUnetModel(UNetModel): - def forward( # noqa + def forward( self, x, timesteps=None, context=None, - control=None, # noqa + control=None, only_mid_control=False, **kwargs, ): @@ -129,10 +129,8 @@ class ControlNet(nn.Module): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) + msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult" + raise ValueError(msg) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not @@ -140,10 +138,8 @@ class ControlNet(nn.Module): if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all( - ( - self.num_res_blocks[i] >= num_attention_blocks[i] - for i in range(len(num_attention_blocks)) - ) + self.num_res_blocks[i] >= num_attention_blocks[i] + for i in range(len(num_attention_blocks)) ) print( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " @@ -425,10 +421,10 @@ class ControlLDM(LatentDiffusion): if is_diffusing: self.model = self.model.cuda() self.control_models = [cm.cuda() for cm in self.control_models] - self.first_stage_model = self.first_stage_model.cpu() # noqa + self.first_stage_model = self.first_stage_model.cpu() self.cond_stage_model = self.cond_stage_model.cpu() else: self.model = self.model.cpu() self.control_models = [cm.cpu() for cm in self.control_models] - self.first_stage_model = self.first_stage_model.cuda() # noqa + self.first_stage_model = self.first_stage_model.cuda() self.cond_stage_model = self.cond_stage_model.cuda() diff --git a/imaginairy/modules/clip_embedders.py b/imaginairy/modules/clip_embedders.py index b713ee9..4f7da44 100644 --- a/imaginairy/modules/clip_embedders.py +++ b/imaginairy/modules/clip_embedders.py @@ -102,9 +102,7 @@ class FrozenClipImageEmbedder(nn.Module): antialias=False, ): super().__init__() - self.model, preprocess = clip.load( # noqa - name=model_name, device=device, jit=jit - ) + self.model, preprocess = clip.load(name=model_name, device=device, jit=jit) self.antialias = antialias diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py index b71fe11..8ed146f 100644 --- a/imaginairy/modules/diffusion/ddpm.py +++ b/imaginairy/modules/diffusion/ddpm.py @@ -9,6 +9,7 @@ import itertools import logging from contextlib import contextmanager, nullcontext from functools import partial +from typing import Optional import numpy as np import pytorch_lightning as pl @@ -371,7 +372,7 @@ class DDPM(pl.LightningModule): # we only modify first two axes assert new_shape[2:] == old_shape[2:] # assumes first axis corresponds to output dim - if not new_shape == old_shape: + if new_shape != old_shape: new_param = param.clone() old_param = sd[name] if len(new_shape) == 1: @@ -495,7 +496,7 @@ class DDPM(pl.LightningModule): img = torch.randn(shape, device=device) intermediates = [img] for i in tqdm( - reversed(range(0, self.num_timesteps)), + reversed(range(self.num_timesteps)), desc="Sampling t", total=self.num_timesteps, ): @@ -563,9 +564,8 @@ class DDPM(pl.LightningModule): elif self.parameterization == "v": target = self.get_v(x_start, noise, t) else: - raise NotImplementedError( - f"Parameterization {self.parameterization} not yet supported" - ) + msg = f"Parameterization {self.parameterization} not yet supported" + raise NotImplementedError(msg) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) @@ -706,7 +706,7 @@ class DDPM(pl.LightningModule): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: - params = params + [self.logvar] + params = [*params, self.logvar] opt = torch.optim.AdamW(params, lr=lr) return opt @@ -716,7 +716,7 @@ def _TileModeConv2DConvForward( ): if self.padding_modeX == self.padding_modeY: self.padding_mode = self.padding_modeX - return self._orig_conv_forward(input, weight, bias) # noqa + return self._orig_conv_forward(input, weight, bias) w1 = F.pad(input, self.paddingX, mode=self.padding_modeX) del input @@ -790,9 +790,7 @@ class LatentDiffusion(DDPM): if isinstance(m, nn.Conv2d): m._initial_padding_mode = m.padding_mode m._orig_conv_forward = m._conv_forward - m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa - m, nn.Conv2d - ) + m._conv_forward = _TileModeConv2DConvForward.__get__(m, nn.Conv2d) self.tile_mode(tile_mode=False) def tile_mode(self, tile_mode): @@ -807,16 +805,16 @@ class LatentDiffusion(DDPM): if m.padding_modeY == m.padding_modeX: m.padding_mode = m.padding_modeX m.paddingX = ( - m._reversed_padding_repeated_twice[0], # noqa - m._reversed_padding_repeated_twice[1], # noqa + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], 0, 0, ) m.paddingY = ( 0, 0, - m._reversed_padding_repeated_twice[2], # noqa - m._reversed_padding_repeated_twice[3], # noqa + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], ) def make_cond_schedule( @@ -896,9 +894,8 @@ class LatentDiffusion(DDPM): elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) + msg = f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + raise NotImplementedError(msg) return self.scale_factor * z def get_learned_conditioning(self, c): @@ -967,7 +964,7 @@ class LatentDiffusion(DDPM): :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) """ - bs, nc, h, w = x.shape # noqa + bs, nc, h, w = x.shape # number of crops in image Ly = (h - kernel_size[0]) // stride[0] + 1 @@ -1167,7 +1164,7 @@ class LatentDiffusion(DDPM): ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) - h, w = x_noisy.shape[-2:] # noqa + h, w = x_noisy.shape[-2:] fold, unfold, normalization, weighting = self.get_fold_unfold( x_noisy, ks, stride @@ -1239,9 +1236,7 @@ class LatentDiffusion(DDPM): # tokenize crop coordinates for the bounding boxes of the respective patches patch_limits_tknzd = [ - torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[ # noqa - None - ].to( # noqa + torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to( self.device ) for bbox in patch_limits @@ -1292,7 +1287,7 @@ class LatentDiffusion(DDPM): return x_recon - def p_losses(self, x_start, cond, t, noise=None): # noqa + def p_losses(self, x_start, cond, t, noise=None): noise = noise if noise is not None else torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) @@ -1374,7 +1369,7 @@ class LatentDiffusion(DDPM): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample( # noqa + def p_sample( self, x, c, @@ -1609,7 +1604,7 @@ class LatentDiffusion(DDPM): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + b, h, w = z.shape[0], z.shape[2], z.shape[3] # noqa mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 @@ -1674,9 +1669,8 @@ class LatentDiffusion(DDPM): logger.info("Training the full unet") params = list(self.model.parameters()) else: - raise ValueError( - f"Unrecognised setting for unet_trainable: {self.unet_trainable}" - ) + msg = f"Unrecognised setting for unet_trainable: {self.unet_trainable}" + raise ValueError(msg) if self.cond_stage_trainable: logger.info( @@ -1706,7 +1700,7 @@ class LatentDiffusion(DDPM): def to_rgb(self, x): x = x.float() if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) # noqa + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -1719,17 +1713,19 @@ class DiffusionWrapper(pl.LightningModule): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm"] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + def forward( + self, x, t, c_concat: Optional[list] = None, c_crossattn: Optional[list] = None + ): if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == "concat": - xc = torch.cat([x] + c_concat, dim=1) + xc = torch.cat([x, *c_concat], dim=1) out = self.diffusion_model(xc, t) elif self.conditioning_key == "crossattn": cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == "hybrid": - xc = torch.cat([x] + c_concat, dim=1) + xc = torch.cat([x, *c_concat], dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == "adm": @@ -1818,7 +1814,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): # print(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images( # noqa + def log_images( self, batch, N=8, @@ -1866,7 +1862,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): if not (self.c_concat_log_start is None and self.c_concat_log_end is None): log["c_concat_decoded"] = self.decode_first_stage( - c_cat[:, self.c_concat_log_start : self.c_concat_log_end] # noqa + c_cat[:, self.c_concat_log_start : self.c_concat_log_end] ) if plot_diffusion_rows: @@ -1929,11 +1925,11 @@ class LatentFinetuneDiffusion(LatentDiffusion): class LatentInpaintDiffusion(LatentDiffusion): - def __init__( # noqa + def __init__( self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", - finetune_keys=None, # noqa + finetune_keys=None, *args, **kwargs, ): diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py index 8bf02b8..043e22e 100644 --- a/imaginairy/modules/diffusion/model.py +++ b/imaginairy/modules/diffusion/model.py @@ -16,8 +16,8 @@ XFORMERS_IS_AVAILABLE = False try: if get_device() == "cuda": - import xformers # noqa - import xformers.ops # noqa + import xformers + import xformers.ops XFORMERS_IS_AVAILABLE = True except ImportError: @@ -415,7 +415,7 @@ class Model(nn.Module): ) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1, *tuple(ch_mult)) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() @@ -581,7 +581,7 @@ class Encoder(nn.Module): ) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1, *tuple(ch_mult)) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -853,10 +853,7 @@ class SimpleDecoder(nn.Module): def forward(self, x): for i, layer in enumerate(self.model): - if i in [1, 2, 3]: - x = layer(x, None) - else: - x = layer(x) + x = layer(x, None) if i in [1, 2, 3] else layer(x) h = self.norm_out(x) h = silu(h) diff --git a/imaginairy/modules/diffusion/openaimodel.py b/imaginairy/modules/diffusion/openaimodel.py index 0eafb26..12e3b09 100644 --- a/imaginairy/modules/diffusion/openaimodel.py +++ b/imaginairy/modules/diffusion/openaimodel.py @@ -1,5 +1,6 @@ import math from abc import abstractmethod +from typing import Optional import numpy as np import torch as th @@ -38,7 +39,7 @@ class AttentionPool2d(nn.Module): spacial_dim: int, embed_dim: int, num_heads_channels: int, - output_dim: int = None, + output_dim: Optional[int] = None, ): super().__init__() self.positional_embedding = nn.Parameter( @@ -519,10 +520,8 @@ class UNetModel(nn.Module): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) + msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult" + raise ValueError(msg) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not diff --git a/imaginairy/modules/diffusion/util.py b/imaginairy/modules/diffusion/util.py index 378d7e8..6893e46 100644 --- a/imaginairy/modules/diffusion/util.py +++ b/imaginairy/modules/diffusion/util.py @@ -52,7 +52,8 @@ def make_beta_schedule( ** 0.5 ) else: - raise ValueError(f"schedule '{schedule}' unknown.") + msg = f"schedule '{schedule}' unknown." + raise ValueError(msg) return betas.numpy() @@ -80,9 +81,8 @@ def make_ddim_timesteps( (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 ).astype(int) else: - raise NotImplementedError( - f'There is no ddim discretization method called "{ddim_discr_method}"' - ) + msg = f'There is no ddim discretization method called "{ddim_discr_method}"' + raise NotImplementedError(msg) # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) @@ -93,7 +93,7 @@ def make_ddim_timesteps( def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta): # select alphas for computing the variance schedule alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + alphas_prev = np.asarray([alphacums[0], *alphacums[ddim_timesteps[:-1]].tolist()]) # according to the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt( @@ -151,7 +151,7 @@ def checkpoint(func, inputs, params, flag): return func(*inputs) -class CheckpointFunction(torch.autograd.Function): # noqa +class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function @@ -180,7 +180,7 @@ class CheckpointFunction(torch.autograd.Function): # noqa del ctx.input_tensors del ctx.input_params del output_tensors - return (None, None) + input_grads + return (None, None, *input_grads) def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): @@ -246,7 +246,7 @@ def normalization(channels): class GroupNorm32(nn.GroupNorm): - def forward(self, x): # noqa + def forward(self, x): return super().forward(x.float()).type(x.dtype) @@ -260,7 +260,8 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv2d(*args, **kwargs) if dims == 3: return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + msg = f"unsupported dimensions: {dims}" + raise ValueError(msg) def linear(*args, **kwargs): @@ -278,7 +279,8 @@ def avg_pool_nd(dims, *args, **kwargs): return nn.AvgPool2d(*args, **kwargs) if dims == 3: return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + msg = f"unsupported dimensions: {dims}" + raise ValueError(msg) class HybridConditioner(nn.Module): diff --git a/imaginairy/modules/encoders.py b/imaginairy/modules/encoders.py index 238dfac..a16cb0b 100644 --- a/imaginairy/modules/encoders.py +++ b/imaginairy/modules/encoders.py @@ -43,7 +43,7 @@ class ClassEmbedder(nn.Module): return uc -def disabled_train(self, mode=True): # noqa +def disabled_train(self, mode=True): """ For disabling train/eval mode. diff --git a/imaginairy/modules/midas/api.py b/imaginairy/modules/midas/api.py index 6b54be6..216bfbb 100644 --- a/imaginairy/modules/midas/api.py +++ b/imaginairy/modules/midas/api.py @@ -61,9 +61,8 @@ def load_midas_transform(model_type="dpt_hybrid"): ) else: - assert ( - False - ), f"model_type '{model_type}' not implemented, use: --model_type large" + msg = f"model_type '{model_type}' not implemented, use: --model_type large" + raise NotImplementedError(msg) transform = Compose( [ @@ -133,8 +132,8 @@ def load_model(model_type): ) else: - print(f"model_type '{model_type}' not implemented, use: --model_type large") - assert False + msg = f"model_type '{model_type}' not implemented, use: --model_type large" + raise NotImplementedError(msg) transform = Compose( [ @@ -155,13 +154,13 @@ def load_model(model_type): return model.eval(), transform -@lru_cache() +@lru_cache def midas_device(): # mps returns incorrect results ~50% of the time return torch.device("cuda" if torch.cuda.is_available() else "cpu") -@lru_cache() +@lru_cache def load_midas(model_type="dpt_hybrid"): model = MiDaSInference(model_type) model.to(midas_device()) diff --git a/imaginairy/modules/midas/midas/base_model.py b/imaginairy/modules/midas/midas/base_model.py index 23c20ac..dad86d1 100644 --- a/imaginairy/modules/midas/midas/base_model.py +++ b/imaginairy/modules/midas/midas/base_model.py @@ -4,7 +4,7 @@ from imaginairy import config from imaginairy.model_manager import get_cached_url_path -class BaseModel(torch.nn.Module): # noqa +class BaseModel(torch.nn.Module): def load(self, path): """ Load model from file. diff --git a/imaginairy/modules/midas/midas/blocks.py b/imaginairy/modules/midas/midas/blocks.py index ddb60e8..1348457 100644 --- a/imaginairy/modules/midas/midas/blocks.py +++ b/imaginairy/modules/midas/midas/blocks.py @@ -56,8 +56,8 @@ def _make_encoder( [32, 48, 136, 384], features, groups=groups, expand=expand ) # efficientnet_lite3 else: - print(f"Backbone '{backbone}' not implemented") - assert False + msg = f"Backbone '{backbone}' not implemented" + raise NotImplementedError(msg) return pretrained, scratch diff --git a/imaginairy/modules/midas/midas/transforms.py b/imaginairy/modules/midas/midas/transforms.py index a471796..1b7b13f 100644 --- a/imaginairy/modules/midas/midas/transforms.py +++ b/imaginairy/modules/midas/midas/transforms.py @@ -135,9 +135,8 @@ class Resize: # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + msg = f"resize_method {self.__resize_method} not implemented" + raise ValueError(msg) if self.__resize_method == "lower_bound": new_height = self.constrain_to_multiple_of( @@ -157,7 +156,8 @@ class Resize: new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) else: - raise ValueError(f"resize_method {self.__resize_method} not implemented") + msg = f"resize_method {self.__resize_method} not implemented" + raise ValueError(msg) return (new_width, new_height) diff --git a/imaginairy/modules/midas/midas/vit.py b/imaginairy/modules/midas/midas/vit.py index ae519f7..b799c28 100644 --- a/imaginairy/modules/midas/midas/vit.py +++ b/imaginairy/modules/midas/midas/vit.py @@ -22,10 +22,7 @@ class AddReadout(nn.Module): self.start_index = start_index def forward(self, x): - if self.start_index == 2: - readout = (x[:, 0] + x[:, 1]) / 2 - else: - readout = x[:, 0] + readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0] return x[:, self.start_index :] + readout.unsqueeze(1) @@ -118,7 +115,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( # noqa + pos_embed = self._resize_pos_embed( self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] ) @@ -174,9 +171,8 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): ProjectReadout(vit_features, start_index) for out_feat in features ] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + msg = "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + raise ValueError(msg) return readout_oper @@ -288,7 +284,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( # noqa + pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) @@ -469,7 +465,7 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. - pretrained.model._resize_pos_embed = types.MethodType( # noqa + pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) diff --git a/imaginairy/modules/midas/utils.py b/imaginairy/modules/midas/utils.py index e189f3d..fded83b 100644 --- a/imaginairy/modules/midas/utils.py +++ b/imaginairy/modules/midas/utils.py @@ -94,9 +94,8 @@ def write_pfm(path, image, scale=1): ): # greyscale color = False else: - raise ValueError( - "Image must have H x W x 3, H x W x 1 or H x W dimensions." - ) + msg = "Image must have H x W x 3, H x W x 1 or H x W dimensions." + raise ValueError(msg) file.write("PF\n" if color else b"Pf\n") file.write(b"%d %d\n" % (image.shape[1], image.shape[0])) @@ -144,10 +143,7 @@ def resize_image(img): height_orig = img.shape[0] width_orig = img.shape[1] - if width_orig > height_orig: - scale = width_orig / 384 - else: - scale = height_orig / 384 + scale = width_orig / 384 if width_orig > height_orig else height_orig / 384 height = (np.ceil(height_orig / scale / 32) * 32).astype(int) width = (np.ceil(width_orig / scale / 32) * 32).astype(int) diff --git a/imaginairy/outpaint.py b/imaginairy/outpaint.py index 4735001..8f6fab0 100644 --- a/imaginairy/outpaint.py +++ b/imaginairy/outpaint.py @@ -217,15 +217,18 @@ def outpaint_arg_str_parse(arg_str): for arg in args: match = arg_pattern.match(arg) if not match: - raise ValueError(f"Invalid outpaint argument '{arg}'") + msg = f"Invalid outpaint argument '{arg}'" + raise ValueError(msg) direction, amount = match.groups() direction = direction.lower() if len(direction) == 1: if direction not in valid_direction_chars: - raise ValueError(f"Invalid outpaint direction '{direction}'") + msg = f"Invalid outpaint direction '{direction}'" + raise ValueError(msg) direction = valid_direction_chars[direction] elif direction not in valid_directions: - raise ValueError(f"Invalid outpaint direction '{direction}'") + msg = f"Invalid outpaint direction '{direction}'" + raise ValueError(msg) kwargs[direction] = int(amount) if "all" in kwargs: diff --git a/imaginairy/prompt_schedules.py b/imaginairy/prompt_schedules.py index d68ffaf..3072180 100644 --- a/imaginairy/prompt_schedules.py +++ b/imaginairy/prompt_schedules.py @@ -11,13 +11,13 @@ def parse_schedule_str(schedule_str): pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]") match = pattern.match(schedule_str) if not match: - raise ValueError(f"Invalid kwarg schedule: {schedule_str}") + msg = f"Invalid kwarg schedule: {schedule_str}" + raise ValueError(msg) arg_name = match.group(1).replace("-", "_") if not hasattr(ImaginePrompt(), arg_name): - raise ValueError( - f"Invalid kwarg schedule. Not a valid argument name: {arg_name}" - ) + msg = f"Invalid kwarg schedule. Not a valid argument name: {arg_name}" + raise ValueError(msg) arg_values = match.group(2) if ":" in arg_values: @@ -53,7 +53,7 @@ def prompt_mutator(prompt, schedules): } """ - schedule_length = len(list(schedules.values())[0]) + schedule_length = len(next(iter(schedules.values()))) for i in range(schedule_length): new_prompt = copy(prompt) for attr_name, schedule in schedules.items(): diff --git a/imaginairy/roi_utils.py b/imaginairy/roi_utils.py index 8459434..12f1c20 100644 --- a/imaginairy/roi_utils.py +++ b/imaginairy/roi_utils.py @@ -24,7 +24,8 @@ def square_roi_coordinate(roi, max_width, max_height, best_effort=False): width = x2 - x1 height = y2 - y1 if not best_effort and width != height: - raise RuntimeError(f"ROI is not square: {width}x{height}") + msg = f"ROI is not square: {width}x{height}" + raise RuntimeError(msg) return x1, y1, x2, y2 @@ -96,8 +97,7 @@ def move_roi_into_bounds(roi, max_width, max_height, force=False): if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height: roi_width = x2 - x1 roi_height = y2 - y1 - raise RoiNotInBoundsError( - f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}" - ) + msg = f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}" + raise RoiNotInBoundsError(msg) return x1, y1, x2, y2 diff --git a/imaginairy/safety.py b/imaginairy/safety.py index 3af6870..4db7e32 100644 --- a/imaginairy/safety.py +++ b/imaginairy/safety.py @@ -113,7 +113,7 @@ class EnhancedStableDiffusionSafetyChecker( return safety_results -@lru_cache() +@lru_cache def safety_models(): safety_model_id = "CompVis/stable-diffusion-safety-checker" monkeypatch_safety_cosine_distance() @@ -124,7 +124,7 @@ def safety_models(): return safety_feature_extractor, safety_checker -@lru_cache() +@lru_cache def monkeypatch_safety_cosine_distance(): orig_cosine_distance = safety_checker_mod.cosine_distance diff --git a/imaginairy/schema.py b/imaginairy/schema.py index 4653d90..bdb68cf 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -51,17 +51,19 @@ class InvalidUrlError(ValueError): class LazyLoadingImage: """Image file encoded as base64 string.""" - def __init__(self, *, filepath=None, url=None, img: Image = None, b64: str = None): + def __init__( + self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None + ): if not filepath and not url and not img and not b64: - raise ValueError( - "You must specify a url or filepath or img or base64 string" - ) + msg = "You must specify a url or filepath or img or base64 string" + raise ValueError(msg) if sum([bool(filepath), bool(url), bool(img), bool(b64)]) > 1: raise ValueError("You cannot multiple input methods") # validate file exists if filepath and not os.path.exists(filepath): - raise FileNotFoundError(f"File does not exist: {filepath}") + msg = f"File does not exist: {filepath}" + raise FileNotFoundError(msg) # validate url is valid url if url: @@ -73,7 +75,8 @@ class LazyLoadingImage: except LocationParseError: raise InvalidUrlError(f"Invalid url: {url}") # noqa if parsed_url.scheme not in {"http", "https"} or not parsed_url.host: - raise InvalidUrlError(f"Invalid url: {url}") + msg = f"Invalid url: {url}" + raise InvalidUrlError(msg) if b64: img = self.load_image_from_base64(b64) @@ -145,16 +148,14 @@ class LazyLoadingImage: raise ValueError(msg) # noqa if isinstance(value, dict): return cls(**value) - raise ValueError( - "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" - ) + msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" + raise ValueError(msg) def handle_b64(value: Any) -> "LazyLoadingImage": if isinstance(value, str): return cls(b64=value) - raise ValueError( - "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" - ) + msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" + raise ValueError(msg) return core_schema.json_or_python_schema( json_schema=core_schema.chain_schema( @@ -349,15 +350,13 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return "" if not isinstance(v, str): - raise ValueError( - f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" - ) + msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" + raise ValueError(msg) # noqa v = v.lower() if v not in valid_tile_modes: - raise ValueError( - f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" - ) + msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" + raise ValueError(msg) return v @field_validator("outpaint", mode="after") @@ -375,7 +374,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return v if not isinstance(v, Tensor): - raise ValueError("conditioning must be a torch.Tensor") + raise ValueError("conditioning must be a torch.Tensor") # noqa return v # @field_validator("init_image", "mask_image", mode="after") @@ -412,13 +411,15 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): @field_validator("mask_image") def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo): if v is not None and info.data.get("mask_prompt") is not None: - raise ValueError("You can only set one of `mask_image` and `mask_prompt`") + msg = "You can only set one of `mask_image` and `mask_prompt`" + raise ValueError(msg) return v @field_validator("mask_prompt", "mask_image", mode="before") def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo): if info.data.get("init_image") is None and v: - raise ValueError("You must set `init_image` if you want to use a mask") + msg = "You must set `init_image` if you want to use a mask" + raise ValueError(msg) return v @field_validator("model", mode="before") @@ -455,9 +456,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): SamplerName.PLMS, SamplerName.DDIM, ): - raise ValueError( - "PLMS and DDIM samplers are not supported for pix2pix edit model." - ) + msg = "PLMS and DDIM samplers are not supported for pix2pix edit model." + raise ValueError(msg) return v @field_validator("steps") @@ -620,7 +620,7 @@ class ImagineResult: self.is_nsfw = is_nsfw self.safety_score = safety_score - self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc) + self.created_at = datetime.now(tz=timezone.utc) self.torch_backend = get_device() self.hardware_name = get_hardware_description(get_device()) @@ -655,9 +655,8 @@ class ImagineResult: def save(self, save_path, image_type="generated"): img = self.images.get(image_type, None) if img is None: - raise ValueError( - f"Image of type {image_type} not stored. Options are: {self.images.keys()}" - ) + msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}" + raise ValueError(msg) img.convert("RGB").save(save_path, exif=self._exif()) diff --git a/imaginairy/surprise_me.py b/imaginairy/surprise_me.py index 62819c4..54d3ab5 100644 --- a/imaginairy/surprise_me.py +++ b/imaginairy/surprise_me.py @@ -192,7 +192,7 @@ def surprise_me_prompts( width=width, height=height, seed=seed, - **kwargs, # noqa + **kwargs, ) ) else: @@ -206,7 +206,7 @@ def surprise_me_prompts( width=width, height=height, seed=seed, - **kwargs, # noqa + **kwargs, ) ) diff --git a/imaginairy/train.py b/imaginairy/train.py index 0bc6fa2..2710dda 100644 --- a/imaginairy/train.py +++ b/imaginairy/train.py @@ -20,6 +20,8 @@ except ImportError: # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL # Use >= 1.6.0 to make this work DDPStrategy = None +import contextlib + from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only @@ -220,10 +222,8 @@ class SetupCallback(Callback): dst, name = os.path.split(self.logdir) dst = os.path.join(dst, "child_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) - try: + with contextlib.suppress(FileNotFoundError): os.rename(self.logdir, dst) - except FileNotFoundError: - pass class ImageLogger(Callback): @@ -342,11 +342,12 @@ class ImageLogger(Callback): ): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") - if hasattr(pl_module, "calibrate_grad_norm"): - if ( - pl_module.calibrate_grad_norm and batch_idx % 25 == 0 - ) and batch_idx > 0: - self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + if ( + hasattr(pl_module, "calibrate_grad_norm") + and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) + and batch_idx > 0 + ): + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) class CUDACallback(Callback): @@ -356,9 +357,9 @@ class CUDACallback(Callback): if "cuda" in get_device(): torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) torch.cuda.synchronize(trainer.strategy.root_device.index) - self.start_time = time.time() # noqa + self.start_time = time.time() - def on_train_epoch_end(self, trainer, pl_module): # noqa + def on_train_epoch_end(self, trainer, pl_module): if "cuda" in get_device(): torch.cuda.synchronize(trainer.strategy.root_device.index) max_memory = ( @@ -394,19 +395,20 @@ def train_diffusion_model( accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf """ if DDPStrategy is None: - raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model") + msg = "Please install pytorch-lightning>=1.6.0 to train a model" + raise ImportError(msg) batch_size = 1 seed = 23 num_workers = 1 num_val_workers = 0 - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005 logdir = os.path.join(logdir, now) ckpt_output_dir = os.path.join(logdir, "checkpoints") cfg_output_dir = os.path.join(logdir, "configs") seed_everything(seed) - model = get_diffusion_model( # noqa + model = get_diffusion_model( weights_location=weights_location, half_mode=False, for_training=True )._model model.learning_rate = learning_rate * accumulate_grad_batches * batch_size @@ -501,9 +503,7 @@ def train_diffusion_model( num_sanity_val_steps=0, accumulate_grad_batches=accumulate_grad_batches, strategy=DDPStrategy(), - callbacks=[ - instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # noqa - ], + callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg], gpus=1, default_root_dir=".", ) diff --git a/imaginairy/training_tools/image_prep.py b/imaginairy/training_tools/image_prep.py index bbaaeb3..88d07f1 100644 --- a/imaginairy/training_tools/image_prep.py +++ b/imaginairy/training_tools/image_prep.py @@ -145,7 +145,7 @@ def create_class_images(class_description, output_folder, num_images=200): while existing_image_count < num_images: prompt = ImaginePrompt(class_description, steps=20) - result = list(imagine([prompt]))[0] + result = next(iter(imagine([prompt]))) if result.is_nsfw: continue dest = os.path.join( diff --git a/imaginairy/training_tools/prune_model.py b/imaginairy/training_tools/prune_model.py index 5a1880a..691bd81 100644 --- a/imaginairy/training_tools/prune_model.py +++ b/imaginairy/training_tools/prune_model.py @@ -29,7 +29,7 @@ def prune_model_data(data, only_keep_ema=True): data.pop("optimizer_states", None) if only_keep_ema: state_dict = data["state_dict"] - model_keys = [k for k in state_dict.keys() if k.startswith("model.")] + model_keys = [k for k in state_dict if k.startswith("model.")] for model_key in model_keys: ema_key = "model_ema." + model_key[6:].replace(".", "") diff --git a/imaginairy/training_tools/single_concept.py b/imaginairy/training_tools/single_concept.py index 434ffe7..d4849b6 100644 --- a/imaginairy/training_tools/single_concept.py +++ b/imaginairy/training_tools/single_concept.py @@ -92,7 +92,8 @@ class SingleConceptDataset(Dataset): try: image = Image.open(img_path).convert("RGB") except RuntimeError as e: - raise RuntimeError(f"Could not read image {img_path}") from e + msg = f"Could not read image {img_path}" + raise RuntimeError(msg) from e image = self.image_transforms(image) data = {"image": image, "txt": txt} return data diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index c339559..6d98af2 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -14,7 +14,7 @@ from torch.overrides import handle_torch_function, has_torch_function_variadic logger = logging.getLogger(__name__) -@lru_cache() +@lru_cache def get_device() -> str: """Return the best torch backend available.""" if torch.cuda.is_available(): @@ -26,7 +26,7 @@ def get_device() -> str: return "cpu" -@lru_cache() +@lru_cache def get_hardware_description(device_type: str) -> str: """Description of the hardware being used.""" desc = platform.platform() @@ -185,10 +185,9 @@ def check_torch_working(): torch.randn(1, device=get_device()) except RuntimeError as e: if "CUDA" in str(e): - raise RuntimeError( - "CUDA is not working. Make sure you have a GPU and CUDA installed." - ) from e - raise e + msg = "CUDA is not working. Make sure you have a GPU and CUDA installed." + raise RuntimeError(msg) from e + raise def frange(start, stop, step): @@ -209,7 +208,7 @@ def shrink_list(items, max_size): new_items = {} for i, item in enumerate(items): new_items[int(i / removal_ratio)] = item - return [items[0]] + list(new_items.values()) + return [items[0], *list(new_items.values())] def glob_expand_paths(paths): diff --git a/imaginairy/utils/data_distorter.py b/imaginairy/utils/data_distorter.py index 99da50f..3d133a3 100644 --- a/imaginairy/utils/data_distorter.py +++ b/imaginairy/utils/data_distorter.py @@ -1,3 +1,4 @@ +import contextlib import math import sys from copy import deepcopy @@ -77,7 +78,7 @@ class DataDistorter: def __init__(self, data, add_data_values=True): self.data = deepcopy(data) self.data_map, self.data_unique_values = create_node_map(self.data) - self.distortion_values = DISTORTED_VALUES + [] + self.distortion_values = [*DISTORTED_VALUES] if add_data_values: self.distortion_values += list(self.data_unique_values) @@ -141,15 +142,13 @@ def create_node_map(data: Union[dict, list, tuple]) -> Tuple[Dict[int, list], se if isinstance(curr_data, dict): for key, value in curr_data.items(): - _traverse(value, curr_path + [key]) + _traverse(value, [*curr_path, key]) elif isinstance(curr_data, (list, tuple)): for idx, item in enumerate(curr_data): - _traverse(item, curr_path + [idx]) + _traverse(item, [*curr_path, idx]) else: - try: + with contextlib.suppress(TypeError): node_values.add(curr_data) - except TypeError: - pass _traverse(data, []) return node_map, node_values diff --git a/imaginairy/utils/model_cache.py b/imaginairy/utils/model_cache.py index ae2eaaa..fcb803f 100644 --- a/imaginairy/utils/model_cache.py +++ b/imaginairy/utils/model_cache.py @@ -203,9 +203,8 @@ class GPUModelCache: total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2) pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0 return total_ram_gb * pct_to_use * (1024**3) - raise ValueError( - f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}" - ) + msg = f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}" + raise ValueError(msg) return self._max_cpu_memory_gb * (1024**3) @cached_property @@ -224,9 +223,8 @@ class GPUModelCache: total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2) pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0 return total_ram_gb * pct_to_use * (1024**3) - raise ValueError( - f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}" - ) + msg = f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}" + raise ValueError(msg) return self._max_gpu_memory_gb * (1024**3) def _move_to_gpu(self, key, model): @@ -280,12 +278,12 @@ class GPUModelCache: import torch if key not in self: - raise KeyError(f"The key {key} does not exist in the cache") + msg = f"The key {key} does not exist in the cache" + raise KeyError(msg) - if key in self.cpu_cache: - if self.device != torch.device("cpu"): - self.cpu_cache.move_to_end(key) - self._move_to_gpu(key, self.cpu_cache[key]) + if key in self.cpu_cache and self.device != torch.device("cpu"): + self.cpu_cache.move_to_end(key) + self._move_to_gpu(key, self.cpu_cache[key]) if key in self.gpu_cache: self.gpu_cache.move_to_end(key) @@ -337,7 +335,7 @@ class MemoryManagedModelWrapper: self._mmmw_kwargs = kwargs self._mmmw_namespace = namespace self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb - self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items()) + self._mmmw_cache_key = (namespace, *args, *tuple(kwargs.items())) def _mmmw_load_model(self): if self._mmmw_cache_key not in self.__class__._mmmw_cache: diff --git a/imaginairy/vendored/k_diffusion/sampling.py b/imaginairy/vendored/k_diffusion/sampling.py index 0b81f2d..ab7b811 100644 --- a/imaginairy/vendored/k_diffusion/sampling.py +++ b/imaginairy/vendored/k_diffusion/sampling.py @@ -85,7 +85,7 @@ class BatchedBrownianTree: seed = [seed] self.batched = False self.trees = [ - torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed + torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed # noqa ] @staticmethod diff --git a/requirements-dev.in b/requirements-dev.in index b69f06a..6234432 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,10 +1,6 @@ black coverage -isort ruff -pycln -pylama -pylint pytest pytest-randomly pytest-sugar diff --git a/requirements-dev.txt b/requirements-dev.txt index 89056ed..9dc9b49 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,8 +16,6 @@ anyio==3.7.1 # via # fastapi # starlette -astroid==2.15.8 - # via pylint async-timeout==4.0.3 # via aiohttp attrs==23.1.0 @@ -36,7 +34,6 @@ click==8.1.7 # click-help-colors # click-shell # imaginAIry (setup.py) - # typer # uvicorn click-help-colors==0.9.2 # via imaginAIry (setup.py) @@ -48,10 +45,8 @@ coverage==7.3.1 # via -r requirements-dev.in cycler==0.12.0 # via matplotlib -diffusers==0.21.3 +diffusers==0.21.4 # via imaginAIry (setup.py) -dill==0.3.7 - # via pylint einops==0.6.1 # via imaginAIry (setup.py) exceptiongroup==1.1.3 @@ -71,7 +66,7 @@ filelock==3.12.4 # transformers filterpy==1.4.5 # via facexlib -fonttools==4.42.1 +fonttools==4.43.0 # via matplotlib frozenlist==1.4.0 # via @@ -104,18 +99,10 @@ importlib-metadata==6.8.0 # via diffusers iniconfig==2.0.0 # via pytest -isort==5.12.0 - # via - # -r requirements-dev.in - # pylint kiwisolver==1.4.5 # via matplotlib kornia==0.7.0 # via imaginAIry (setup.py) -lazy-object-proxy==1.9.0 - # via astroid -libcst==1.0.1 - # via pycln lightning-utilities==0.9.0 # via # pytorch-lightning @@ -126,18 +113,12 @@ matplotlib==3.7.3 # via # -c tests/constraints.txt # filterpy -mccabe==0.7.0 - # via - # pylama - # pylint multidict==6.0.4 # via # aiohttp # yarl mypy-extensions==1.0.0 - # via - # black - # typing-inspect + # via black numba==0.58.0 # via facexlib numpy==1.24.4 @@ -178,9 +159,7 @@ packaging==23.1 # pytorch-lightning # transformers pathspec==0.11.2 - # via - # black - # pycln + # via black pillow==10.0.1 # via # diffusers @@ -190,9 +169,7 @@ pillow==10.0.1 # matplotlib # torchvision platformdirs==3.10.0 - # via - # black - # pylint + # via black pluggy==1.3.0 # via pytest protobuf==3.20.3 @@ -201,24 +178,12 @@ protobuf==3.20.3 # open-clip-torch psutil==5.9.5 # via imaginAIry (setup.py) -pycln==2.2.2 - # via -r requirements-dev.in -pycodestyle==2.11.0 - # via pylama pydantic==2.4.2 # via # fastapi # imaginAIry (setup.py) pydantic-core==2.10.1 # via pydantic -pydocstyle==6.3.0 - # via pylama -pyflakes==3.1.0 - # via pylama -pylama==8.4.1 - # via -r requirements-dev.in -pylint==2.17.6 - # via -r requirements-dev.in pyparsing==3.1.1 # via matplotlib pytest==7.4.2 @@ -237,9 +202,7 @@ pytorch-lightning==1.9.5 pyyaml==6.0.1 # via # huggingface-hub - # libcst # omegaconf - # pycln # pytorch-lightning # responses # timm @@ -280,8 +243,6 @@ six==1.16.0 # via python-dateutil sniffio==1.3.0 # via anyio -snowballstemmer==2.2.0 - # via pydocstyle starlette==0.27.0 # via fastapi termcolor==2.3.0 @@ -295,12 +256,7 @@ tokenizers==0.13.3 tomli==2.0.1 # via # black - # pylint # pytest -tomlkit==0.12.1 - # via - # pycln - # pylint torch==1.13.1 # via # facexlib @@ -335,28 +291,20 @@ tqdm==4.66.1 # transformers transformers==4.33.3 # via imaginAIry (setup.py) -typer==0.9.0 - # via pycln types-pyyaml==6.0.12.12 # via responses typing-extensions==4.8.0 # via - # astroid # black # fastapi # huggingface-hub - # libcst # lightning-utilities # pydantic # pydantic-core # pytorch-lightning # torch # torchvision - # typer - # typing-inspect # uvicorn -typing-inspect==0.9.0 - # via libcst urllib3==2.0.5 # via # requests @@ -367,8 +315,6 @@ wcwidth==0.2.7 # via ftfy wheel==0.41.2 # via -r requirements-dev.in -wrapt==1.15.0 - # via astroid yarl==1.9.2 # via aiohttp zipp==3.17.0 diff --git a/scripts/controlnet_convert.py b/scripts/controlnet_convert.py index 46d587f..4723003 100644 --- a/scripts/controlnet_convert.py +++ b/scripts/controlnet_convert.py @@ -52,8 +52,8 @@ def main(): time.sleep(1) controlnet_statedict = torch.load(controlnet_path, map_location="cpu") print("\n\nComparing reconstructed controlnet with original") - for k in controlnet_statedict.keys(): - if k not in reconstituted_controlnet_statedict.keys(): + for k in controlnet_statedict: + if k not in reconstituted_controlnet_statedict: print(f"Key {k} not in reconstituted") elif ( controlnet_statedict[k].shape diff --git a/scripts/prep_vocab_lists.py b/scripts/prep_vocab_lists.py index f6809f8..11fa2ad 100644 --- a/scripts/prep_vocab_lists.py +++ b/scripts/prep_vocab_lists.py @@ -54,7 +54,7 @@ def make_txts(): with open(src_json, encoding="utf-8") as f: prompts = json.load(f) categories = [] - for c in prompts.keys(): + for c in prompts: if any(c.startswith(p) for p in excluded_prefixes): continue categories.append(c) diff --git a/setup.py b/setup.py index a611975..3f2591a 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ else: entry_points = None -@lru_cache() +@lru_cache def get_git_revision_hash() -> str: try: return ( diff --git a/tests/conftest.py b/tests/conftest.py index 6342ef2..a460be8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import contextlib import logging import os import sys @@ -33,16 +34,14 @@ elif get_device() == "cpu": @pytest.fixture(scope="session", autouse=True) -def pre_setup(): +def _pre_setup(): api.IMAGINAIRY_SAFETY_MODE = "disabled" suppress_annoying_logs_and_warnings() test_output_folder = f"{TESTS_FOLDER}/test_output" # delete the testoutput folder and recreate it - try: + with contextlib.suppress(FileNotFoundError): rmtree(test_output_folder) - except FileNotFoundError: - pass os.makedirs(test_output_folder, exist_ok=True) orig_urlopen = HTTPConnectionPool.urlopen @@ -73,7 +72,7 @@ def pre_setup(): @pytest.fixture(autouse=True) -def reset_get_device(): +def _reset_get_device(): get_device.cache_clear() @@ -94,7 +93,7 @@ def sampler_type(request): return request.param -@pytest.fixture +@pytest.fixture() def mocked_responses(): with responses.RequestsMock() as rsps: yield rsps diff --git a/tests/enhancers/test_blur_detect.py b/tests/enhancers/test_blur_detect.py index 49960dd..e99a7d0 100644 --- a/tests/enhancers/test_blur_detect.py +++ b/tests/enhancers/test_blur_detect.py @@ -12,7 +12,7 @@ blur_params = [ ] -@pytest.mark.parametrize("img_path,expected", blur_params) +@pytest.mark.parametrize(("img_path", "expected"), blur_params) def test_calculate_blurriness_level(img_path, expected): img = Image.open(img_path) diff --git a/tests/img_processors/test_control_modes.py b/tests/img_processors/test_control_modes.py index 69285e2..b6d768e 100644 --- a/tests/img_processors/test_control_modes.py +++ b/tests/img_processors/test_control_modes.py @@ -15,7 +15,7 @@ def control_img_to_pillow_img(img_t): control_mode_params = list(CONTROL_MODES.items()) -@pytest.mark.parametrize("control_name,control_func", control_mode_params) +@pytest.mark.parametrize(("control_name", "control_func"), control_mode_params) def test_control_images(filename_base_for_outputs, control_func, control_name): seed_everything(42) img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png") diff --git a/tests/modules/test_autoencoders.py b/tests/modules/test_autoencoders.py index 6f817be..f68fe47 100644 --- a/tests/modules/test_autoencoders.py +++ b/tests/modules/test_autoencoders.py @@ -26,7 +26,7 @@ strat_combos = [ @pytest.mark.skipif(True, reason="Run manually as needed. Uses too much memory.") -@pytest.mark.parametrize("encode_strat,decode_strat", strat_combos) +@pytest.mark.parametrize(("encode_strat", "decode_strat"), strat_combos) def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat): """ Test that encoding and decoding works. diff --git a/tests/ruff.toml b/tests/ruff.toml new file mode 100644 index 0000000..6562f53 --- /dev/null +++ b/tests/ruff.toml @@ -0,0 +1,14 @@ +extend-ignore = ["E501", "G004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100"] +extend-exclude = ["imaginairy/vendored", "downloads", "other"] + +extend-select = [ + "I", "E", "W", "UP", "ASYNC", "BLE", "A001", "A002", + "C4", "DTZ", "T10", "EM", "ISC", "ICN", "G", "PIE", "PT", + "Q", "SIM", "TID", "TCH", "PLC", "PLE", "TRY", "RUF" +] + +[isort] +combine-as-imports = true + +[flake8-errmsg] +max-string-length = 50 \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 4b96d96..84ea4ff 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -174,10 +174,7 @@ def test_img_to_img_fruit_2_gold( filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg") ) target_steps = 25 - if init_strength >= 1: - needed_steps = 25 - else: - needed_steps = int(target_steps / (1 - init_strength)) + needed_steps = 25 if init_strength >= 1 else int(target_steps / (1 - init_strength)) prompt = ImaginePrompt( "a white bowl filled with gold coins", prompt_strength=12, diff --git a/tests/test_enhancers.py b/tests/test_enhancers.py index 45ea0fd..db2ebde 100644 --- a/tests/test_enhancers.py +++ b/tests/test_enhancers.py @@ -126,7 +126,7 @@ boolean_mask_test_cases = [ ] -@pytest.mark.parametrize("mask_text,expected", boolean_mask_test_cases) +@pytest.mark.parametrize(("mask_text", "expected"), boolean_mask_test_cases) def test_clip_mask_parser(mask_text, expected): parsed = MASK_PROMPT.parseString(mask_text)[0][0] assert str(parsed) == expected diff --git a/tests/test_feather_tile.py b/tests/test_feather_tile.py index 7dbb24e..748f52e 100644 --- a/tests/test_feather_tile.py +++ b/tests/test_feather_tile.py @@ -46,7 +46,7 @@ cases = [ ] -@pytest.mark.parametrize("img_ratio, tile_size, overlap_pct", cases) +@pytest.mark.parametrize(("img_ratio", "tile_size", "overlap_pct"), cases) def test_feather_tile_simple(img_ratio, tile_size, overlap_pct): img = pillow_img_to_torch_image( LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg") diff --git a/tests/test_outpaint.py b/tests/test_outpaint.py index 46f2150..5027fb7 100644 --- a/tests/test_outpaint.py +++ b/tests/test_outpaint.py @@ -23,7 +23,7 @@ def test_outpainting_outpaint(filename_base_for_outputs): steps=20, seed=542906833, ) - result = list(imagine([prompt]))[0] + result = next(iter(imagine([prompt]))) img_path = f"{filename_base_for_outputs}.png" assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=17000) @@ -37,6 +37,6 @@ outpaint_test_params = [ ] -@pytest.mark.parametrize("arg_str, expected_kwargs", outpaint_test_params) +@pytest.mark.parametrize(("arg_str", "expected_kwargs"), outpaint_test_params) def test_outpaint_parse_kwargs(arg_str, expected_kwargs): assert outpaint_arg_str_parse(arg_str) == expected_kwargs diff --git a/tests/test_prompt_schedules.py b/tests/test_prompt_schedules.py index 969c65d..5266ac9 100644 --- a/tests/test_prompt_schedules.py +++ b/tests/test_prompt_schedules.py @@ -5,7 +5,7 @@ from imaginairy.utils import frange @pytest.mark.parametrize( - "schedule_str,expected", + ("schedule_str", "expected"), [ ("prompt_strength[2:40:1]", ("prompt_strength", list(range(2, 40)))), ("prompt_strength[2:40:0.5]", ("prompt_strength", list(frange(2, 40, 0.5)))), diff --git a/tests/test_schema/test_lazy_load_image.py b/tests/test_schema/test_lazy_load_image.py index 23f80db..57bf9d6 100644 --- a/tests/test_schema/test_lazy_load_image.py +++ b/tests/test_schema/test_lazy_load_image.py @@ -26,7 +26,7 @@ def _red_url(mocked_responses): status=200, content_type="image/png", ) - yield url + return url @pytest.fixture(name="red_path") diff --git a/tests/test_utils/test_model_cache.py b/tests/test_utils/test_model_cache.py index ef9b46f..f27df82 100644 --- a/tests/test_utils/test_model_cache.py +++ b/tests/test_utils/test_model_cache.py @@ -107,7 +107,7 @@ def test_cache_ordering(): ) cache.set("key-0", create_model_of_n_bytes(4_000_000)) - assert list(cache.cpu_cache.keys()) == [] # noqa + assert list(cache.cpu_cache.keys()) == [] assert list(cache.gpu_cache.keys()) == ["key-0"] assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == ( 0, @@ -115,7 +115,7 @@ def test_cache_ordering(): ) cache.set("key-1", create_model_of_n_bytes(4_000_000)) - assert list(cache.cpu_cache.keys()) == [] # noqa + assert list(cache.cpu_cache.keys()) == [] assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"] assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == ( 0, diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 293bede..d8913ec 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -68,7 +68,7 @@ def test_instantiate_from_config(): "params": {"year": 2002, "month": 10, "day": 1}, } o = instantiate_from_config(config) - assert o == datetime(2002, 10, 1) + assert o == datetime(2002, 10, 1) # noqa: DTZ001 config = "__is_first_stage__" assert instantiate_from_config(config) is None diff --git a/tox.ini b/tox.ini index 676afb3..f507962 100644 --- a/tox.ini +++ b/tox.ini @@ -4,25 +4,3 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored filterwarnings = ignore::DeprecationWarning ignore::UserWarning - - -[pylama] -format = pylint -skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py,.eggs/* -linters = pylint,pycodestyle,pyflakes,mypy -ignore = - Z999,C0103,C0201,C0301,C0302,C0114,C0115,C0116,C0415, - Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417, - Z999,E203,E501,E1101,E1121,E1131,E1133,E1135,E1136, - Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702, - Z999,W0221,W0511,W0612,W0613,W0632,W1203 - -[pylama:tests/*] -ignore = C0104,C0114,C0116,D103,W0143,W0613 - -[pylama:*/__init__.py] -ignore = D104 - -[pylama:pylint] -generated_members=torch.* -extension-pkg-whitelist=pydantic