style: speed up linting and autoformatting. fix lints

Branch: pull/385/head
Author: Bryce (committed by Bryce Drennan)
Parent: 460add16b8
Commit: 558d3388e5

@@ -9,35 +9,42 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4.5.0
         with:
           python-version: 3.9
-          cache: pip
-          cache-dependency-path: requirements-dev.txt
-      - name: Install dependencies
-        run: |
-          python -m pip install --disable-pip-version-check wheel pip-tools
-          pip-sync requirements-dev.txt
-          python -m pip install --disable-pip-version-check --no-deps .
+      - name: Cache dependencies
+        uses: actions/cache@v3.2.4
+        id: cache
+        with:
+          path: ${{ env.pythonLocation }}
+          key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-lint
+      - name: Install Ruff
+        if: steps.cache.outputs.cache-hit != 'true'
+        run: grep -E 'ruff==' requirements-dev.txt | xargs pip install
       - name: Lint
         run: |
           echo "::add-matcher::.github/pylama_matcher.json"
-          pylama --options tox.ini
+          ruff --config tests/ruff.toml .
   autoformat:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4.5.0
         with:
           python-version: 3.9
-      - name: Install dependencies
-        run: |
-          python -m pip install --disable-pip-version-check black==23.1.0 isort==5.12.0
-      - name: Autoformatter
-        run: |
-          black --diff .
-          isort --atomic --profile black --check-only .
+      - name: Cache dependencies
+        uses: actions/cache@v3.2.4
+        id: cache
+        with:
+          path: ${{ env.pythonLocation }}
+          key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-autoformat
+      - name: Install Black
+        if: steps.cache.outputs.cache-hit != 'true'
+        run: grep -E 'black==' requirements-dev.txt | xargs pip install
+      - name: Lint
+        run: |
+          black --diff --fast .
   test:
     runs-on: ubuntu-latest
     strategy:

@@ -28,18 +28,16 @@ init: require_pyenv ## Setup a dev environment for local development.
 af: autoformat  ## Alias for `autoformat`

 autoformat: ## Run the autoformatter.
-	@pycln . --all --quiet --extend-exclude __init__\.py
-	@# ERA,T201
-	@-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,W --unfixable T,ERA --fix-only .
+	@-ruff check --config tests/ruff.toml . --fix-only
 	@black .
-	@isort --atomic --profile black --skip downloads/** .

 test: ## Run the tests.
 	@pytest
 	@echo -e "The tests pass! ✨ 🍰 ✨"

 lint: ## Run the code linter.
-	@pylama
+	@ruff check --config tests/ruff.toml .
 	@echo -e "No linting errors - well done! ✨ 🍰 ✨"

 deploy: ## Deploy the package to pypi.org

@ -63,7 +63,7 @@ def generate_image_morph_video():
if os.path.exists(filename): if os.path.exists(filename):
continue continue
result = list(imagine([prompt]))[0] result = next(iter(imagine([prompt])))
generated_image = result.images["generated"] generated_image = result.images["generated"]
draw = ImageDraw.Draw(generated_image) draw = ImageDraw.Draw(generated_image)
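Note: the rewrites from list(imagine([prompt]))[0] to next(iter(...)) here and in colorize_img below take only the first item from the generator instead of materializing the whole list to index element 0; the rule code is my best guess (ruff RUF015), and this sketch just illustrates the pattern with the names from the diff:

    results = imagine([prompt])         # lazy generator of results
    first_result = next(iter(results))  # consumes a single item
    # old form, which exhausts the generator before indexing:
    # first_result = list(results)[0]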

@ -33,7 +33,7 @@ def make_bounce_animation(
middle_imgs = shrink_list(middle_imgs, max_frames) middle_imgs = shrink_list(middle_imgs, max_frames)
frames = [first_img] + middle_imgs + [last_img] + list(reversed(middle_imgs)) frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))]
# convert from latents # convert from latents
converted_frames = [] converted_frames = []
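Note: the list rewrites such as [first_img] + middle_imgs + [last_img] becoming [first_img, *middle_imgs, last_img] follow ruff's preference for unpacking inside a single literal over chained concatenation (likely rule RUF005); a minimal, self-contained sketch:

    first, middle, last = "a", ["b", "c"], "d"
    # concatenation builds intermediate lists:
    frames_old = [first] + middle + [last] + list(reversed(middle))
    # unpacking builds the result in one literal:
    frames_new = [first, *middle, last, *reversed(middle)]
    assert frames_old == frames_new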

@ -92,13 +92,13 @@ def imagine_image_files(
os.makedirs(subpath, exist_ok=True) os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(subpath, f"{basefilename}.gif") filepath = os.path.join(subpath, f"{basefilename}.gif")
frames = result.progress_latents + [result.images["generated"]] frames = [*result.progress_latents, result.images["generated"]]
if prompt.init_image: if prompt.init_image:
resized_init_image = pillow_fit_image_within( resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height prompt.init_image, prompt.width, prompt.height
) )
frames = [resized_init_image] + frames frames = [resized_init_image, *frames]
frames.reverse() frames.reverse()
make_bounce_animation( make_bounce_animation(
imgs=frames, imgs=frames,
@ -170,7 +170,7 @@ def imagine(
logger.info( logger.info(
f"🖼 Generating {i + 1}/{num_prompts}: {prompt.prompt_description()}" f"🖼 Generating {i + 1}/{num_prompts}: {prompt.prompt_description()}"
) )
for attempt in range(0, unsafe_retry_count + 1): for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(prompt.seed, int): if attempt > 0 and isinstance(prompt.seed, int):
prompt.seed += 100_000_000 + attempt prompt.seed += 100_000_000 + attempt
result = _generate_single_image( result = _generate_single_image(
@ -238,7 +238,7 @@ def _generate_single_image(
latent_channels = 4 latent_channels = 4
downsampling_factor = 8 downsampling_factor = 8
batch_size = 1 batch_size = 1
global _most_recent_result # noqa global _most_recent_result
# handle prompt pulling in previous values # handle prompt pulling in previous values
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"): # if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
# _, img_type = prompt.init_image.strip("*").split(".") # _, img_type = prompt.init_image.strip("*").split(".")
@@ -457,16 +457,17 @@ def _generate_single_image(
             if control_image_t.shape[1] != 3:
                 raise RuntimeError("Control image must have 3 channels")
-            if control_input.mode != "inpaint":
-                if control_image_t.min() < 0 or control_image_t.max() > 1:
-                    raise RuntimeError(
-                        f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
-                    )
+            if (
+                control_input.mode != "inpaint"
+                and control_image_t.min() < 0
+                or control_image_t.max() > 1
+            ):
+                msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
+                raise RuntimeError(msg)
             if control_image_t.max() == control_image_t.min():
-                raise RuntimeError(
-                    f"No control signal found in control image {control_input.mode}."
-                )
+                msg = f"No control signal found in control image {control_input.mode}."
+                raise RuntimeError(msg)
             c_cat.append(control_image_t)
             control_strengths.append(control_input.strength)
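Note: in the merged condition above, Python's `and` binds tighter than `or`, so the new single check is not strictly equivalent to the original nested one (an "inpaint" control image with values above 1 now also raises). Preserving the original behaviour would need explicit parentheses, roughly:

    if control_input.mode != "inpaint" and (
        control_image_t.min() < 0 or control_image_t.max() > 1
    ):
        msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
        raise RuntimeError(msg)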
@ -517,7 +518,7 @@ def _generate_single_image(
if ( if (
prompt.allow_compose_phase prompt.allow_compose_phase
and not is_controlnet_model and not is_controlnet_model
and not model.cond_stage_key == "edit" and model.cond_stage_key != "edit"
): ):
if prompt.init_image: if prompt.init_image:
comp_image = _generate_composition_image( comp_image = _generate_composition_image(

@ -4,6 +4,7 @@ import logging
import shlex import shlex
import traceback import traceback
from functools import update_wrapper from functools import update_wrapper
from typing import ClassVar
import click import click
from click_help_colors import HelpColorsCommand, HelpColorsMixin from click_help_colors import HelpColorsCommand, HelpColorsMixin
@ -43,27 +44,23 @@ def mod_get_invoke(command):
# and that's not ideal when running in a shell. # and that's not ideal when running in a shell.
pass pass
except Exception as e: # noqa except Exception as e: # noqa
traceback.print_exception(e) # noqa traceback.print_exception(e)
# logger.warning(traceback.format_exc()) # logger.warning(traceback.format_exc())
# Always return False so the shell doesn't exit # Always return False so the shell doesn't exit
return False return False
invoke_ = update_wrapper(invoke_, command.callback) invoke_ = update_wrapper(invoke_, command.callback)
invoke_.__name__ = "do_%s" % command.name # noqa invoke_.__name__ = "do_%s" % command.name
return invoke_ return invoke_
 class ModClickShell(ClickShell):
     def add_command(self, cmd, name):
         # Use the MethodType to add these as bound methods to our current instance
-        setattr(
-            self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self)  # noqa
-        )
-        setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self))  # noqa
-        setattr(
-            self, "complete_%s" % name, get_method_type(get_complete(cmd), self)  # noqa
-        )
+        setattr(self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self))
+        setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self))
+        setattr(self, "complete_%s" % name, get_method_type(get_complete(cmd), self))
class ModShell(Shell): class ModShell(Shell):
@ -85,7 +82,7 @@ class ColorShell(HelpColorsMixin, ModShell):
class ImagineColorsCommand(HelpColorsCommand): class ImagineColorsCommand(HelpColorsCommand):
_option_order = [] _option_order: ClassVar = []
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
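Note: the `_option_order: ClassVar = []` annotations added here and in ModifiedMask below mark mutable class-level defaults as deliberately shared across instances, which is what ruff asks for (my assumption: rule RUF012). The class name in this sketch is illustrative only:

    from typing import ClassVar

    class OptionRecorderSketch:
        _option_order: ClassVar[list] = []  # shared by all instances, by design

        def record(self, option):
            self._option_order.append(option)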

@ -43,7 +43,7 @@ remove_option(edit_options, "allow_compose_phase")
) )
@add_options(edit_options) @add_options(edit_options)
@click.pass_context @click.pass_context
def edit_cmd( # noqa def edit_cmd(
ctx, ctx,
image_paths, image_paths,
image_strength, image_strength,
@ -77,7 +77,7 @@ def edit_cmd( # noqa
model_weights_path, model_weights_path,
model_config_path, model_config_path,
prompt_library_path, prompt_library_path,
version, # noqa version,
make_gif, make_gif,
make_compare_gif, make_compare_gif,
arg_schedules, arg_schedules,
@ -130,7 +130,7 @@ def edit_cmd( # noqa
model_weights_path, model_weights_path,
model_config_path, model_config_path,
prompt_library_path, prompt_library_path,
version, # noqa version,
make_gif, make_gif,
make_compare_gif, make_compare_gif,
arg_schedules, arg_schedules,

@ -90,7 +90,7 @@ def imagine_cmd(
model_weights_path, model_weights_path,
model_config_path, model_config_path,
prompt_library_path, prompt_library_path,
version, # noqa version,
make_gif, make_gif,
make_compare_gif, make_compare_gif,
arg_schedules, arg_schedules,
@ -110,7 +110,7 @@ def imagine_cmd(
# hacky method of getting order of control images (mixing raw and normal images) # hacky method of getting order of control images (mixing raw and normal images)
control_images = [ control_images = [
(o, path) (o, path)
for o, path in ImagineColorsCommand._option_order # noqa for o, path in ImagineColorsCommand._option_order
if o.name in ("control_image", "control_image_raw") if o.name in ("control_image", "control_image_raw")
] ]
control_inputs = [] control_inputs = []
@ -176,7 +176,7 @@ def imagine_cmd(
model_weights_path, model_weights_path,
model_config_path, model_config_path,
prompt_library_path, prompt_library_path,
version, # noqa version,
make_gif, make_gif,
make_compare_gif, make_compare_gif,
arg_schedules, arg_schedules,
@ -187,4 +187,4 @@ def imagine_cmd(
if __name__ == "__main__": if __name__ == "__main__":
imagine_cmd() # noqa imagine_cmd()

@ -92,4 +92,4 @@ def model_list_cmd():
if __name__ == "__main__": if __name__ == "__main__":
aimg() # noqa aimg()

@ -43,7 +43,7 @@ def _imagine_cmd(
model_weights_path, model_weights_path,
model_config_path, model_config_path,
prompt_library_path, prompt_library_path,
version=False, # noqa version=False,
make_gif=False, make_gif=False,
make_compare_gif=False, make_compare_gif=False,
arg_schedules=None, arg_schedules=None,
@ -78,10 +78,7 @@ def _imagine_cmd(
configure_logging(log_level) configure_logging(log_level)
-    if isinstance(init_image, str):
-        init_images = [init_image]
-    else:
-        init_images = init_image
+    init_images = [init_image] if isinstance(init_image, str) else init_image
from imaginairy.utils import glob_expand_paths from imaginairy.utils import glob_expand_paths
@ -89,9 +86,8 @@ def _imagine_cmd(
init_images = glob_expand_paths(init_images) init_images = glob_expand_paths(init_images)
     if len(init_images) < num_prexpaned_init_images:
-        raise ValueError(
-            f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
-        )
+        msg = f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
+        raise ValueError(msg)
total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
logger.info( logger.info(
@ -227,7 +223,8 @@ def replace_option(options, option_name, new_option):
if option.name == option_name: if option.name == option_name:
options[i] = new_option options[i] = new_option
return return
raise ValueError(f"Option {option_name} not found") msg = f"Option {option_name} not found"
raise ValueError(msg)
def remove_option(options, option_name): def remove_option(options, option_name):
@ -242,7 +239,8 @@ def remove_option(options, option_name):
if option.name == option_name: if option.name == option_name:
del options[i] del options[i]
return return
raise ValueError(f"Option {option_name} not found") msg = f"Option {option_name} not found"
raise ValueError(msg)
common_options = [ common_options = [
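Note: the rewrite pattern used throughout this commit, from raising with an inline f-string to binding the message to a variable first, is the fix ruff suggests for messages embedded directly in a raise (my assumption: the EM101/EM102 rules); a minimal before/after sketch using the option-lookup error above:

    def replace_option_sketch(options, option_name, new_option):
        # before (flagged): raise ValueError(f"Option {option_name} not found")
        # after: bind the message first, then raise
        msg = f"Option {option_name} not found"
        raise ValueError(msg)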

@ -36,7 +36,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None):
steps=30, steps=30,
prompt_strength=12, prompt_strength=12,
) )
result = list(imagine(prompt))[0] result = next(iter(imagine(prompt)))
colorized_img = replace_color(img, result.images["generated"]) colorized_img = replace_color(img, result.images["generated"])
# allows the algorithm some leeway for the overall brightness of the image # allows the algorithm some leeway for the overall brightness of the image

@ -18,6 +18,7 @@ Examples:
""" """
import operator import operator
from abc import ABC from abc import ABC
from typing import ClassVar
import pyparsing as pp import pyparsing as pp
import torch import torch
@ -57,7 +58,7 @@ class SimpleMask(Mask):
class ModifiedMask(Mask): class ModifiedMask(Mask):
ops = { ops: ClassVar = {
"+": operator.add, "+": operator.add,
"-": operator.sub, "-": operator.sub,
"*": operator.mul, "*": operator.mul,
@ -80,7 +81,7 @@ class ModifiedMask(Mask):
return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1]) return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])
def __repr__(self): def __repr__(self):
return f"{repr(self.mask)}{self.modifier}" return f"{self.mask!r}{self.modifier}"
def gather_text_descriptions(self): def gather_text_descriptions(self):
return self.mask.gather_text_descriptions() return self.mask.gather_text_descriptions()
@ -141,7 +142,8 @@ class NestedMask(Mask):
elif self.op == "NOT": elif self.op == "NOT":
mask = 1 - mask mask = 1 - mask
else: else:
raise ValueError(f"Invalid operand {self.op}") msg = f"Invalid operand {self.op}"
raise ValueError(msg)
return torch.clamp(mask, 0, 1) return torch.clamp(mask, 0, 1)

@ -14,9 +14,9 @@ from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth" weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
@lru_cache() @lru_cache
def clip_mask_model(): def clip_mask_model():
from imaginairy.paths import PKG_ROOT # noqa from imaginairy.paths import PKG_ROOT
model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64, complex_trans_conv=True) model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64, complex_trans_conv=True)
model.eval() model.eval()
@ -36,7 +36,7 @@ def get_img_mask(
mask_description_statement: str, mask_description_statement: str,
threshold: Optional[float] = None, threshold: Optional[float] = None,
): ):
from imaginairy.enhancers.bool_masker import MASK_PROMPT # noqa from imaginairy.enhancers.bool_masker import MASK_PROMPT
parsed = MASK_PROMPT.parseString(mask_description_statement) parsed = MASK_PROMPT.parseString(mask_description_statement)
parsed_mask = parsed[0][0] parsed_mask = parsed[0][0]

@ -17,9 +17,9 @@ if "mps" in device:
BLIP_EVAL_SIZE = 384 BLIP_EVAL_SIZE = 384
@lru_cache() @lru_cache
def blip_model(): def blip_model():
from imaginairy.paths import PKG_ROOT # noqa from imaginairy.paths import PKG_ROOT
config_path = os.path.join( config_path = os.path.join(
PKG_ROOT, "vendored", "blip", "configs", "med_config.json" PKG_ROOT, "vendored", "blip", "configs", "med_config.json"
@ -28,7 +28,7 @@ def blip_model():
model = BLIP_Decoder(image_size=BLIP_EVAL_SIZE, vit="base", med_config=config_path) model = BLIP_Decoder(image_size=BLIP_EVAL_SIZE, vit="base", med_config=config_path)
cached_url_path = get_cached_url_path(url) cached_url_path = get_cached_url_path(url)
model, msg = load_checkpoint(model, cached_url_path) # noqa model, msg = load_checkpoint(model, cached_url_path)
model.eval() model.eval()
model = model.to(device) model = model.to(device)
return model return model
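Note: the `@lru_cache()` to `@lru_cache` changes on these model loaders drop the redundant call parentheses; since Python 3.8 the decorator can be applied bare with its default maxsize (my assumption: ruff rule UP011). A minimal sketch with an illustrative function name:

    from functools import lru_cache

    @lru_cache  # equivalent to @lru_cache() with the default maxsize
    def expensive_model_sketch():
        return object()  # stand-in for the expensive model load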

@ -10,7 +10,7 @@ from imaginairy.vendored import clip
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache() @lru_cache
def get_model(): def get_model():
model_name = "ViT-L/14" model_name = "ViT-L/14"
model, preprocess = clip.load(model_name, device=device) model, preprocess = clip.load(model_name, device=device)

@ -17,7 +17,7 @@ face_restore_device = torch.device("cuda" if torch.cuda.is_available() else "cpu
half_mode = face_restore_device == "cuda" half_mode = face_restore_device == "cuda"
@lru_cache() @lru_cache
def codeformer_model(): def codeformer_model():
model = CodeFormer( model = CodeFormer(
dim_embd=512, dim_embd=512,
@ -36,7 +36,7 @@ def codeformer_model():
return model return model
@lru_cache() @lru_cache
def face_restore_helper(): def face_restore_helper():
""" """
Provide a singleton of FaceRestoreHelper. Provide a singleton of FaceRestoreHelper.
@ -85,11 +85,11 @@ def enhance_faces(img, fidelity=0):
try: try:
with torch.no_grad(): with torch.no_grad():
output = net(cropped_face_t, w=fidelity, adain=True)[0] # noqa output = net(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output del output
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception as error: # noqa except Exception as error:
logger.exception(f"\tFailed inference for CodeFormer: {error}") logger.exception(f"\tFailed inference for CodeFormer: {error}")
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

@ -15,10 +15,10 @@ formatter = Formatter()
PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+") PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+")
@lru_cache() @lru_cache
def prompt_library_filepaths(prompt_library_paths=None): def prompt_library_filepaths(prompt_library_paths=None):
"""Return all available category/filepath pairs.""" """Return all available category/filepath pairs."""
prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths prompt_library_paths = prompt_library_paths if prompt_library_paths else []
combined_prompt_library_filepaths = {} combined_prompt_library_filepaths = {}
for prompt_path in DEFAULT_PROMPT_LIBRARY_PATHS + list(prompt_library_paths): for prompt_path in DEFAULT_PROMPT_LIBRARY_PATHS + list(prompt_library_paths):
library_prompts = prompt_library_filepath(prompt_path) library_prompts = prompt_library_filepath(prompt_path)
@ -27,7 +27,7 @@ def prompt_library_filepaths(prompt_library_paths=None):
return combined_prompt_library_filepaths return combined_prompt_library_filepaths
@lru_cache() @lru_cache
def category_list(prompt_library_paths=None): def category_list(prompt_library_paths=None):
"""Return the names of available phrase-lists.""" """Return the names of available phrase-lists."""
categories = list(prompt_library_filepaths(prompt_library_paths).keys()) categories = list(prompt_library_filepaths(prompt_library_paths).keys())
@ -35,7 +35,7 @@ def category_list(prompt_library_paths=None):
return categories return categories
@lru_cache() @lru_cache
def prompt_library_filepath(library_path): def prompt_library_filepath(library_path):
lookup = {} lookup = {}
@ -55,9 +55,8 @@ def get_phrases(category_name, prompt_library_paths=None):
try: try:
filepath = lookup[category_name] filepath = lookup[category_name]
except KeyError as e: except KeyError as e:
-        raise LookupError(
-            f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file."
-        ) from e
+        msg = f"'{category_name}' is not a valid prompt expansion category. Could not find the txt file."
+        raise LookupError(msg) from e
_open = open _open = open
if filepath.endswith(".gz"): if filepath.endswith(".gz"):
_open = gzip.open _open = gzip.open
@ -83,13 +82,12 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None):
""" """
prompt_parts = list(formatter.parse(prompt_text)) prompt_parts = list(formatter.parse(prompt_text))
field_names = [] field_names = []
for literal_text, field_name, format_spec, conversion in prompt_parts: # noqa for literal_text, field_name, format_spec, conversion in prompt_parts:
if field_name: if field_name:
field_name = field_name.lower() field_name = field_name.lower()
if not PROMPT_EXPANSION_PATTERN.match(field_name): if not PROMPT_EXPANSION_PATTERN.match(field_name):
-            raise ValueError(
-                "Invalid prompt expansion. Only a-z0-9_|- characters permitted. "
-            )
+            msg = "Invalid prompt expansion. Only a-z0-9_|- characters permitted. "
+            raise ValueError(msg)
field_names.append(field_name) field_names.append(field_name)
phrases = [] phrases = []
@ -120,9 +118,7 @@ def expand_prompts(prompt_text, n=1, prompt_library_paths=None):
yield output_prompt yield output_prompt
-def get_random_non_repeating_combination(  # noqa
-    n=1, *sequences, allow_oversampling=True
-):
+def get_random_non_repeating_combination(n=1, *sequences, allow_oversampling=True):
""" """
Efficiently return a non-repeating random sample of the product sequences. Efficiently return a non-repeating random sample of the product sequences.

@ -187,7 +187,7 @@ class CLIPEmbedder(nn.Module):
) )
@lru_cache() @lru_cache
def clip_up_models(): def clip_up_models():
with platform_appropriate_autocast(): with platform_appropriate_autocast():
tok_up = CLIPTokenizerTransform() tok_up = CLIPTokenizerTransform()
@ -290,7 +290,8 @@ def upscale_latent(
eta=eta, eta=eta,
**sampler_opts, **sampler_opts,
) )
raise ValueError(f"Unknown sampler {sampler}") msg = f"Unknown sampler {sampler}"
raise ValueError(msg)
for _ in range((num_samples - 1) // batch_size + 1): for _ in range((num_samples - 1) // batch_size + 1):
if noise_aug_type == "gaussian": if noise_aug_type == "gaussian":
@ -300,7 +301,7 @@ def upscale_latent(
elif noise_aug_type == "fake": elif noise_aug_type == "fake":
latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5 latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5
extra_args = { extra_args = {
"low_res": latent_noised, # noqa "low_res": latent_noised,
"low_res_sigma": low_res_sigma, "low_res_sigma": low_res_sigma,
"c": c, "c": c,
} }

@ -63,7 +63,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage") initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage")
@validator("seed") @validator("seed")
def validate_seed(cls, v): # noqa def validate_seed(cls, v):
if v == 0: if v == 0:
return None return None
return v return v
@ -74,10 +74,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
from PIL import Image from PIL import Image
-        if self.prompts:
-            positive_prompt = self.prompts[0].text
-        else:
-            positive_prompt = None
+        positive_prompt = self.prompts[0].text if self.prompts else None
if self.prompts and len(self.prompts) > 1: if self.prompts and len(self.prompts) > 1:
negative_prompt = self.prompts[1].text if len(self.prompts) > 1 else None negative_prompt = self.prompts[1].text if len(self.prompts) > 1 else None
else: else:

@ -79,7 +79,7 @@ def _create_depth_map_raw(img):
align_corners=False, align_corners=False,
) )
depth_pt = model(img)[0] # noqa depth_pt = model(img)[0]
return depth_pt return depth_pt
@ -209,7 +209,10 @@ def inpaint_prep(mask_image_t, target_image_t):
def to_grayscale(img): def to_grayscale(img):
# The dimensions of input should be (batch_size, channels, height, width) # The dimensions of input should be (batch_size, channels, height, width)
-    assert img.dim() == 4 and img.size(1) == 3
+    if img.dim() != 4:
+        raise ValueError("Input should be a 4d tensor")
+    if img.size(1) != 3:
+        raise ValueError("Input should have 3 channels")
# Apply the formula to convert to grayscale. # Apply the formula to convert to grayscale.
gray = ( gray = (

@ -3,7 +3,7 @@ from collections import OrderedDict
from functools import lru_cache from functools import lru_cache
import cv2 import cv2
import matplotlib import matplotlib as mpl
import numpy as np import numpy as np
import torch import torch
from scipy.ndimage.filters import gaussian_filter from scipy.ndimage.filters import gaussian_filter
@ -40,7 +40,7 @@ def pad_right_down_corner(img, stride, padValue):
def transfer(model, model_weights): def transfer(model, model_weights):
# transfer caffe model to pytorch which will match the layer name # transfer caffe model to pytorch which will match the layer name
transfered_model_weights = {} transfered_model_weights = {}
for weights_name in model.state_dict().keys(): for weights_name in model.state_dict():
transfered_model_weights[weights_name] = model_weights[ transfered_model_weights[weights_name] = model_weights[
".".join(weights_name.split(".")[1:]) ".".join(weights_name.split(".")[1:])
] ]
@ -93,14 +93,14 @@ def draw_bodypose(canvas, candidate, subset):
[255, 0, 85], [255, 0, 85],
] ]
for i in range(18): for i in range(18):
for n in range(len(subset)): # noqa for n in range(len(subset)):
index = int(subset[n][i]) index = int(subset[n][i])
if index == -1: if index == -1:
continue continue
x, y = candidate[index][0:2] x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17): for i in range(17):
for n in range(len(subset)): # noqa for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1] index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index: if -1 in index:
continue continue
@ -155,8 +155,7 @@ def draw_handpose(canvas, all_hand_peaks, show_number=False):
canvas, canvas,
(x1, y1), (x1, y1),
(x2, y2), (x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
* 255,
thickness=2, thickness=2,
) )
@ -368,7 +367,7 @@ class bodypose_model(nn.Module):
] ]
) )
for k in blocks.keys(): for k in blocks:
blocks[k] = make_layers(blocks[k], no_relu_layers) blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"] self.model1_1 = blocks["block1_1"]
@ -473,7 +472,7 @@ class handpose_model(nn.Module):
] ]
) )
for k in blocks.keys(): for k in blocks:
blocks[k] = make_layers(blocks[k], no_relu_layers) blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"] self.model1_0 = blocks["block1_0"]
@ -625,7 +624,7 @@ def create_body_pose(original_img_t):
peaks = list( peaks = list(
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]) zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
) # note reverse ) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] peaks_with_score = [(*x, map_ori[x[1], x[0]]) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks)) peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [ peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id)) peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))
@ -751,7 +750,7 @@ def create_body_pose(original_img_t):
connection_candidate, key=lambda x: x[2], reverse=True connection_candidate, key=lambda x: x[2], reverse=True
) )
connection = np.zeros((0, 5)) connection = np.zeros((0, 5))
for c in range(len(connection_candidate)): # noqa for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3] i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]: if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack( connection = np.vstack(

@ -101,9 +101,10 @@ def torch_img_to_pillow_img(img_t: torch.Tensor):
elif img_t.shape[1] == 3: elif img_t.shape[1] == 3:
colorspace = "RGB" colorspace = "RGB"
else: else:
raise ValueError( msg = (
f"Unsupported colorspace. {img_t.shape[1]} channels in {img_t.shape} shape" f"Unsupported colorspace. {img_t.shape[1]} channels in {img_t.shape} shape"
) )
raise ValueError(msg)
img_t = rearrange(img_t, "b c h w -> b h w c") img_t = rearrange(img_t, "b c h w -> b h w c")
img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0) img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0)
img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0] img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0]
@ -113,7 +114,7 @@ def torch_img_to_pillow_img(img_t: torch.Tensor):
def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image: def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image:
from imaginairy.model_manager import get_current_diffusion_model # noqa from imaginairy.model_manager import get_current_diffusion_model
if len(latent.shape) == 3: if len(latent.shape) == 3:
latent = latent.unsqueeze(0) latent = latent.unsqueeze(0)

@ -94,13 +94,13 @@ class ImageLoggingContext:
self._prev_log_context = None self._prev_log_context = None
def __enter__(self): def __enter__(self):
global _CURRENT_LOGGING_CONTEXT # noqa global _CURRENT_LOGGING_CONTEXT
self._prev_log_context = _CURRENT_LOGGING_CONTEXT self._prev_log_context = _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self _CURRENT_LOGGING_CONTEXT = self
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_LOGGING_CONTEXT # noqa global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self._prev_log_context _CURRENT_LOGGING_CONTEXT = self._prev_log_context
def timing(self, description): def timing(self, description):
@@ -120,21 +120,20 @@ class ImageLoggingContext:
         )

     def log_latents(self, latents, description):
-        from imaginairy.img_utils import model_latents_to_pillow_imgs  # noqa
+        from imaginairy.img_utils import model_latents_to_pillow_imgs

         if "predicted_latent" in description:
             if self.progress_latent_callback is not None:
                 self.progress_latent_callback(latents)
             if (
                 self.step_count - self.last_progress_img_step
-            ) > self.progress_img_interval_steps:
-                if (
-                    time.perf_counter() - self.last_progress_img_ts
-                    > self.progress_img_interval_min_s
-                ):
-                    self.log_progress_latent(latents)
-                    self.last_progress_img_step = self.step_count
-                    self.last_progress_img_ts = time.perf_counter()
+            ) > self.progress_img_interval_steps and (
+                time.perf_counter() - self.last_progress_img_ts
+                > self.progress_img_interval_min_s
+            ):
+                self.log_progress_latent(latents)
+                self.last_progress_img_step = self.step_count
+                self.last_progress_img_ts = time.perf_counter()

         if not self.debug_img_callback:
             return
@ -168,7 +167,7 @@ class ImageLoggingContext:
) )
def log_progress_latent(self, latent): def log_progress_latent(self, latent):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa from imaginairy.img_utils import model_latents_to_pillow_imgs
if not self.progress_img_callback: if not self.progress_img_callback:
return return
@ -280,7 +279,7 @@ def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger as pytorch_logger from pytorch_lightning import _logger as pytorch_logger
try: try:
from pytorch_lightning.utilities.seed import log # noqa from pytorch_lightning.utilities.seed import log
log.setLevel(logging.NOTSET) log.setLevel(logging.NOTSET)
log.handlers = [] log.handlers = []

@ -24,9 +24,8 @@ class LambdaWarmUpCosineScheduler:
self.verbosity_interval = verbosity_interval self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs): def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps: if n < self.lr_warm_up_steps:
lr = ( lr = (
self.lr_max - self.lr_start self.lr_max - self.lr_start
@ -66,7 +65,7 @@ class LambdaWarmUpCosineScheduler2:
self.f_min = f_min self.f_min = f_min
self.f_max = f_max self.f_max = f_max
self.cycle_lengths = cycle_lengths self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) self.cum_cycles = np.cumsum([0, *list(self.cycle_lengths)])
self.last_f = 0.0 self.last_f = 0.0
self.verbosity_interval = verbosity_interval self.verbosity_interval = verbosity_interval
@ -81,12 +80,11 @@ class LambdaWarmUpCosineScheduler2:
def schedule(self, n, **kwargs): def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n) cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle] n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            print(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
if n < self.lr_warm_up_steps[cycle]: if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle cycle
@ -112,12 +110,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs): def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n) cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle] n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            print(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
if n < self.lr_warm_up_steps[cycle]: if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[

@@ -7,9 +7,11 @@ from functools import wraps

 import requests
 import torch
-from huggingface_hub import HfFolder
-from huggingface_hub import hf_hub_download as _hf_hub_download
-from huggingface_hub import try_to_load_from_cache
+from huggingface_hub import (
+    HfFolder,
+    hf_hub_download as _hf_hub_download,
+    try_to_load_from_cache,
+)
 from omegaconf import OmegaConf
 from safetensors.torch import load_file

@@ -65,16 +67,19 @@ def load_state_dict(weights_location, half_mode=False, device=None):
                 f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.'
             )
             sys.exit(1)
-        raise e
+        raise
     except RuntimeError as e:
-        if "PytorchStreamReader failed reading zip archive" in str(e):
-            if weights_location.startswith("http"):
-                logger.warning("Corrupt checkpoint. deleting and re-downloading...")
-                os.remove(ckpt_path)
-                ckpt_path = get_cached_url_path(weights_location, category="weights")
-                state_dict = load_tensors(ckpt_path, map_location="cpu")
+        err_str = str(e)
+        if (
+            "PytorchStreamReader failed reading zip archive" in err_str
+            and weights_location.startswith("http")
+        ):
+            logger.warning("Corrupt checkpoint. deleting and re-downloading...")
+            os.remove(ckpt_path)
+            ckpt_path = get_cached_url_path(weights_location, category="weights")
+            state_dict = load_tensors(ckpt_path, map_location="cpu")
         if state_dict is None:
-            raise e
+            raise

     state_dict = state_dict.get("state_dict", state_dict)
@ -166,7 +171,7 @@ def get_diffusion_model(
except HuggingFaceAuthorizationError as e: except HuggingFaceAuthorizationError as e:
if for_inpainting: if for_inpainting:
logger.warning( logger.warning(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}" f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
) )
return _get_diffusion_model( return _get_diffusion_model(
iconfig.DEFAULT_MODEL, iconfig.DEFAULT_MODEL,
@ -176,7 +181,7 @@ def get_diffusion_model(
for_training=for_training, for_training=for_training,
control_weights_locations=control_weights_locations, control_weights_locations=control_weights_locations,
) )
raise e raise
def _get_diffusion_model( def _get_diffusion_model(
@ -192,7 +197,7 @@ def _get_diffusion_model(
Weights location may also be shortcut name, e.g. "SD-1.5" Weights location may also be shortcut name, e.g. "SD-1.5"
""" """
global MOST_RECENTLY_LOADED_MODEL # noqa global MOST_RECENTLY_LOADED_MODEL
( (
model_config, model_config,
@ -293,9 +298,8 @@ def resolve_model_paths(
if for_training: if for_training:
weights_path = model_metadata_w.weights_url_full weights_path = model_metadata_w.weights_url_full
if weights_path is None: if weights_path is None:
raise ValueError( msg = "No full training weights configured for this model. Edit the code or subimt a github issue."
"No full training weights configured for this model. Edit the code or subimt a github issue." raise ValueError(msg)
)
else: else:
weights_path = model_metadata_w.weights_url weights_path = model_metadata_w.weights_url
@ -306,9 +310,8 @@ def resolve_model_paths(
config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path
if control_net_metadatas: if control_net_metadatas:
if "stable-diffusion-v1" not in config_path: if "stable-diffusion-v1" not in config_path:
raise ValueError( msg = "Control net is only supported for stable diffusion v1. Please use a different model."
"Control net is only supported for stable diffusion v1. Please use a different model." raise ValueError(msg)
)
control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas] control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas]
config_path = control_net_metadatas[0].config_path config_path = control_net_metadatas[0].config_path
model_metadata = model_metadata_w or model_metadata_c model_metadata = model_metadata_w or model_metadata_c
@ -374,7 +377,7 @@ def get_cached_url_path(url, category=None):
os.rename(old_dest_path, dest_path) os.rename(old_dest_path, dest_path)
return dest_path return dest_path
r = requests.get(url) # noqa r = requests.get(url)
with open(dest_path, "wb") as f: with open(dest_path, "wb") as f:
f.write(r.content) f.write(r.content)
@ -390,12 +393,8 @@ def check_huggingface_url_authorized(url):
headers["authorization"] = f"Bearer {token}" headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5) response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
     if response.status_code == 401:
-        raise HuggingFaceAuthorizationError(
-            "Unauthorized access to HuggingFace model. This model requires a huggingface token. "
-            "Please login to HuggingFace "
-            "or set HUGGING_FACE_HUB_TOKEN to your User Access Token. "
-            "See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
-        )
+        msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
+        raise HuggingFaceAuthorizationError(msg)
     return None
@ -413,7 +412,7 @@ def hf_hub_download(*args, **kwargs):
if "unexpected keyword argument 'token'" in str(e): if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token") kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs) return _hf_hub_download(*args, **kwargs)
raise e raise
def huggingface_cached_path(url): def huggingface_cached_path(url):
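Note: the `raise e` to bare `raise` changes in these except blocks re-raise the exception currently being handled without rebinding it, which keeps the original traceback intact (my assumption: ruff rule TRY201). A minimal, purely illustrative sketch:

    def load_checkpoint_sketch(path):
        try:
            with open(path, "rb") as f:
                return f.read()
        except OSError:
            # bare `raise` re-raises the active exception with its original
            # traceback; `raise e` inside the handler is redundant.
            raise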

@ -14,8 +14,8 @@ XFORMERS_IS_AVAILABLE = False
try: try:
if get_device() == "cuda": if get_device() == "cuda":
import xformers # noqa import xformers
import xformers.ops # noqa import xformers.ops
XFORMERS_IS_AVAILABLE = True XFORMERS_IS_AVAILABLE = True
except ImportError: except ImportError:
@ -79,7 +79,7 @@ class LinearAttention(nn.Module):
self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape # noqa b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange( q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
@ -120,7 +120,7 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) v = self.v(h_)
# compute attention # compute attention
b, c, h, w = q.shape # noqa b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c") q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)") k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k) w_ = torch.einsum("bij,bjk->bik", q, k)
@ -183,7 +183,7 @@ class CrossAttention(nn.Module):
# if mask is None and _global_mask_hack is not None: # if mask is None and _global_mask_hack is not None:
# mask = _global_mask_hack.to(torch.bool) # mask = _global_mask_hack.to(torch.bool)
if get_device() == "cuda" or "mps" in get_device(): if get_device() == "cuda" or "mps" in get_device(): # noqa
if not XFORMERS_IS_AVAILABLE and ALLOW_SPLITMEM: if not XFORMERS_IS_AVAILABLE and ALLOW_SPLITMEM:
return self.forward_splitmem(x, context=context, mask=mask) return self.forward_splitmem(x, context=context, mask=mask)
@ -222,7 +222,7 @@ class CrossAttention(nn.Module):
out = rearrange(out, "(b h) n d -> b n (h d)", h=h) out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out) return self.to_out(out)
def forward_splitmem(self, x, context=None, mask=None): # noqa def forward_splitmem(self, x, context=None, mask=None):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
@ -262,10 +262,8 @@ class CrossAttention(nn.Module):
max_res = ( max_res = (
math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
) )
raise RuntimeError( msg = f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). " raise RuntimeError(msg)
f"Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
)
slice_size = ( slice_size = (
q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
) )
@ -474,7 +472,7 @@ class SpatialTransformer(nn.Module):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list): if not isinstance(context, list):
context = [context] context = [context]
b, c, h, w = x.shape # noqa b, c, h, w = x.shape
x_in = x x_in = x
x = self.norm(x) x = self.norm(x)
if self.use_linear: if self.use_linear:

@ -290,10 +290,7 @@ class AutoencoderKL(pl.LightningModule):
def forward(self, input, sample_posterior=True): # noqa def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input) posterior = self.encode(input)
-        if sample_posterior:
-            z = posterior.sample()
-        else:
-            z = posterior.mode()
+        z = posterior.sample() if sample_posterior else posterior.mode()
dec = self.decode(z) dec = self.decode(z)
return dec, posterior return dec, posterior
@ -484,7 +481,7 @@ class AutoencoderKL(pl.LightningModule):
:param x: img of size (bs, c, h, w) :param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
""" """
bs, nc, h, w = x.shape # noqa bs, nc, h, w = x.shape
# number of crops in image # number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1 Ly = (h - kernel_size[0]) // stride[0] + 1

@ -19,12 +19,12 @@ from imaginairy.modules.diffusion.util import (
class ControlledUnetModel(UNetModel): class ControlledUnetModel(UNetModel):
def forward( # noqa def forward(
self, self,
x, x,
timesteps=None, timesteps=None,
context=None, context=None,
control=None, # noqa control=None,
only_mid_control=False, only_mid_control=False,
**kwargs, **kwargs,
): ):
@ -129,10 +129,8 @@ class ControlNet(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks] self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else: else:
if len(num_res_blocks) != len(channel_mult): if len(num_res_blocks) != len(channel_mult):
-                raise ValueError(
-                    "provide num_res_blocks either as an int (globally constant) or "
-                    "as a list/tuple (per-level) with the same length as channel_mult"
-                )
+                msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult"
+                raise ValueError(msg)
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None: if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
@ -140,10 +138,8 @@ class ControlNet(nn.Module):
if num_attention_blocks is not None: if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks) assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(
-                (
-                    self.num_res_blocks[i] >= num_attention_blocks[i]
-                    for i in range(len(num_attention_blocks))
-                )
+                self.num_res_blocks[i] >= num_attention_blocks[i]
+                for i in range(len(num_attention_blocks))
             )
print( print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
@ -425,10 +421,10 @@ class ControlLDM(LatentDiffusion):
if is_diffusing: if is_diffusing:
self.model = self.model.cuda() self.model = self.model.cuda()
self.control_models = [cm.cuda() for cm in self.control_models] self.control_models = [cm.cuda() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cpu() # noqa self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu() self.cond_stage_model = self.cond_stage_model.cpu()
else: else:
self.model = self.model.cpu() self.model = self.model.cpu()
self.control_models = [cm.cpu() for cm in self.control_models] self.control_models = [cm.cpu() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cuda() # noqa self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda() self.cond_stage_model = self.cond_stage_model.cuda()

@ -102,9 +102,7 @@ class FrozenClipImageEmbedder(nn.Module):
antialias=False, antialias=False,
): ):
super().__init__() super().__init__()
self.model, preprocess = clip.load( # noqa self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
name=model_name, device=device, jit=jit
)
self.antialias = antialias self.antialias = antialias

@ -9,6 +9,7 @@ import itertools
import logging import logging
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from functools import partial from functools import partial
from typing import Optional
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
@ -371,7 +372,7 @@ class DDPM(pl.LightningModule):
# we only modify first two axes # we only modify first two axes
assert new_shape[2:] == old_shape[2:] assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim # assumes first axis corresponds to output dim
if not new_shape == old_shape: if new_shape != old_shape:
new_param = param.clone() new_param = param.clone()
old_param = sd[name] old_param = sd[name]
if len(new_shape) == 1: if len(new_shape) == 1:
@ -495,7 +496,7 @@ class DDPM(pl.LightningModule):
img = torch.randn(shape, device=device) img = torch.randn(shape, device=device)
intermediates = [img] intermediates = [img]
for i in tqdm( for i in tqdm(
reversed(range(0, self.num_timesteps)), reversed(range(self.num_timesteps)),
desc="Sampling t", desc="Sampling t",
total=self.num_timesteps, total=self.num_timesteps,
): ):
@ -563,9 +564,8 @@ class DDPM(pl.LightningModule):
elif self.parameterization == "v": elif self.parameterization == "v":
target = self.get_v(x_start, noise, t) target = self.get_v(x_start, noise, t)
else: else:
raise NotImplementedError( msg = f"Parameterization {self.parameterization} not yet supported"
f"Parameterization {self.parameterization} not yet supported" raise NotImplementedError(msg)
)
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
@ -706,7 +706,7 @@ class DDPM(pl.LightningModule):
lr = self.learning_rate lr = self.learning_rate
params = list(self.model.parameters()) params = list(self.model.parameters())
if self.learn_logvar: if self.learn_logvar:
params = params + [self.logvar] params = [*params, self.logvar]
opt = torch.optim.AdamW(params, lr=lr) opt = torch.optim.AdamW(params, lr=lr)
return opt return opt
@ -716,7 +716,7 @@ def _TileModeConv2DConvForward(
): ):
if self.padding_modeX == self.padding_modeY: if self.padding_modeX == self.padding_modeY:
self.padding_mode = self.padding_modeX self.padding_mode = self.padding_modeX
return self._orig_conv_forward(input, weight, bias) # noqa return self._orig_conv_forward(input, weight, bias)
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX) w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
del input del input
@ -790,9 +790,7 @@ class LatentDiffusion(DDPM):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
m._initial_padding_mode = m.padding_mode m._initial_padding_mode = m.padding_mode
m._orig_conv_forward = m._conv_forward m._orig_conv_forward = m._conv_forward
m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa m._conv_forward = _TileModeConv2DConvForward.__get__(m, nn.Conv2d)
m, nn.Conv2d
)
self.tile_mode(tile_mode=False) self.tile_mode(tile_mode=False)
def tile_mode(self, tile_mode): def tile_mode(self, tile_mode):
@ -807,16 +805,16 @@ class LatentDiffusion(DDPM):
if m.padding_modeY == m.padding_modeX: if m.padding_modeY == m.padding_modeX:
m.padding_mode = m.padding_modeX m.padding_mode = m.padding_modeX
m.paddingX = ( m.paddingX = (
m._reversed_padding_repeated_twice[0], # noqa m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1], # noqa m._reversed_padding_repeated_twice[1],
0, 0,
0, 0,
) )
m.paddingY = ( m.paddingY = (
0, 0,
0, 0,
m._reversed_padding_repeated_twice[2], # noqa m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3], # noqa m._reversed_padding_repeated_twice[3],
) )
def make_cond_schedule( def make_cond_schedule(
@ -896,9 +894,8 @@ class LatentDiffusion(DDPM):
elif isinstance(encoder_posterior, torch.Tensor): elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior z = encoder_posterior
else: else:
raise NotImplementedError( msg = f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" raise NotImplementedError(msg)
)
return self.scale_factor * z return self.scale_factor * z
def get_learned_conditioning(self, c): def get_learned_conditioning(self, c):
@ -967,7 +964,7 @@ class LatentDiffusion(DDPM):
:param x: img of size (bs, c, h, w) :param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
""" """
bs, nc, h, w = x.shape # noqa bs, nc, h, w = x.shape
# number of crops in image # number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1 Ly = (h - kernel_size[0]) // stride[0] + 1
@ -1167,7 +1164,7 @@ class LatentDiffusion(DDPM):
ks = self.split_input_params["ks"] # eg. (128, 128) ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64) stride = self.split_input_params["stride"] # eg. (64, 64)
h, w = x_noisy.shape[-2:] # noqa h, w = x_noisy.shape[-2:]
fold, unfold, normalization, weighting = self.get_fold_unfold( fold, unfold, normalization, weighting = self.get_fold_unfold(
x_noisy, ks, stride x_noisy, ks, stride
@ -1239,9 +1236,7 @@ class LatentDiffusion(DDPM):
# tokenize crop coordinates for the bounding boxes of the respective patches # tokenize crop coordinates for the bounding boxes of the respective patches
patch_limits_tknzd = [ patch_limits_tknzd = [
-                torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[  # noqa
-                    None
-                ].to(  # noqa
+                torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(
self.device self.device
) )
for bbox in patch_limits for bbox in patch_limits
@ -1292,7 +1287,7 @@ class LatentDiffusion(DDPM):
return x_recon return x_recon
def p_losses(self, x_start, cond, t, noise=None): # noqa def p_losses(self, x_start, cond, t, noise=None):
noise = noise if noise is not None else torch.randn_like(x_start) noise = noise if noise is not None else torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond) model_output = self.apply_model(x_noisy, t, cond)
@ -1374,7 +1369,7 @@ class LatentDiffusion(DDPM):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample( # noqa def p_sample(
self, self,
x, x,
c, c,
@ -1609,7 +1604,7 @@ class LatentDiffusion(DDPM):
if inpaint: if inpaint:
# make a simple center square # make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3] b, h, w = z.shape[0], z.shape[2], z.shape[3] # noqa
mask = torch.ones(N, h, w).to(self.device) mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in # zeros will be filled in
mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
@ -1674,9 +1669,8 @@ class LatentDiffusion(DDPM):
logger.info("Training the full unet") logger.info("Training the full unet")
params = list(self.model.parameters()) params = list(self.model.parameters())
else: else:
raise ValueError( msg = f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
f"Unrecognised setting for unet_trainable: {self.unet_trainable}" raise ValueError(msg)
)
if self.cond_stage_trainable: if self.cond_stage_trainable:
logger.info( logger.info(
@ -1706,7 +1700,7 @@ class LatentDiffusion(DDPM):
def to_rgb(self, x): def to_rgb(self, x):
x = x.float() x = x.float()
if not hasattr(self, "colorize"): if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) # noqa self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize) x = nn.functional.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x
@ -1719,17 +1713,19 @@ class DiffusionWrapper(pl.LightningModule):
self.conditioning_key = conditioning_key self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm"] assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm"]
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): def forward(
self, x, t, c_concat: Optional[list] = None, c_crossattn: Optional[list] = None
):
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t) out = self.diffusion_model(x, t)
elif self.conditioning_key == "concat": elif self.conditioning_key == "concat":
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x, *c_concat], dim=1)
out = self.diffusion_model(xc, t) out = self.diffusion_model(xc, t)
elif self.conditioning_key == "crossattn": elif self.conditioning_key == "crossattn":
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc) out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == "hybrid": elif self.conditioning_key == "hybrid":
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x, *c_concat], dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc) out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == "adm": elif self.conditioning_key == "adm":
@ -1818,7 +1814,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
# print(f"Unexpected Keys: {unexpected}") # print(f"Unexpected Keys: {unexpected}")
@torch.no_grad() @torch.no_grad()
def log_images( # noqa def log_images(
self, self,
batch, batch,
N=8, N=8,
@ -1866,7 +1862,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
if not (self.c_concat_log_start is None and self.c_concat_log_end is None): if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage( log["c_concat_decoded"] = self.decode_first_stage(
c_cat[:, self.c_concat_log_start : self.c_concat_log_end] # noqa c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
) )
if plot_diffusion_rows: if plot_diffusion_rows:
@ -1929,11 +1925,11 @@ class LatentFinetuneDiffusion(LatentDiffusion):
class LatentInpaintDiffusion(LatentDiffusion): class LatentInpaintDiffusion(LatentDiffusion):
def __init__( # noqa def __init__(
self, self,
concat_keys=("mask", "masked_image"), concat_keys=("mask", "masked_image"),
masked_image_key="masked_image", masked_image_key="masked_image",
finetune_keys=None, # noqa finetune_keys=None,
*args, *args,
**kwargs, **kwargs,
): ):
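
The hunks above repeatedly swap inline exception messages for a `msg` variable, the style enforced by ruff's flake8-errmsg (EM) rules selected in this commit's config. A minimal sketch of the pattern with a hypothetical validator (the function name and accepted values are invented for illustration, not project code):

    def validate_unet_trainable(unet_trainable: str) -> None:
        # Build the message first, then raise it, so the raise line carries
        # no long f-string expression.
        if unet_trainable not in ("attn", "all"):  # hypothetical settings
            msg = f"Unrecognised setting for unet_trainable: {unet_trainable}"
            raise ValueError(msg)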

@ -16,8 +16,8 @@ XFORMERS_IS_AVAILABLE = False
try: try:
if get_device() == "cuda": if get_device() == "cuda":
import xformers # noqa import xformers
import xformers.ops # noqa import xformers.ops
XFORMERS_IS_AVAILABLE = True XFORMERS_IS_AVAILABLE = True
except ImportError: except ImportError:
@ -415,7 +415,7 @@ class Model(nn.Module):
) )
curr_res = resolution curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult) in_ch_mult = (1, *tuple(ch_mult))
self.down = nn.ModuleList() self.down = nn.ModuleList()
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
block = nn.ModuleList() block = nn.ModuleList()
@ -581,7 +581,7 @@ class Encoder(nn.Module):
) )
curr_res = resolution curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult) in_ch_mult = (1, *tuple(ch_mult))
self.in_ch_mult = in_ch_mult self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList() self.down = nn.ModuleList()
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
@ -853,10 +853,7 @@ class SimpleDecoder(nn.Module):
def forward(self, x): def forward(self, x):
for i, layer in enumerate(self.model): for i, layer in enumerate(self.model):
if i in [1, 2, 3]: x = layer(x, None) if i in [1, 2, 3] else layer(x)
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x) h = self.norm_out(x)
h = silu(h) h = silu(h)
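
The `(1,) + tuple(ch_mult)` concatenations above become iterable unpacking, matching ruff's RUF005 suggestion. A small sketch with made-up values:

    ch_mult = (1, 2, 4, 4)          # hypothetical channel multipliers
    in_ch_mult = (1, *ch_mult)      # same result as (1,) + tuple(ch_mult)
    assert in_ch_mult == (1, 1, 2, 4, 4)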

@ -1,5 +1,6 @@
import math import math
from abc import abstractmethod from abc import abstractmethod
from typing import Optional
import numpy as np import numpy as np
import torch as th import torch as th
@ -38,7 +39,7 @@ class AttentionPool2d(nn.Module):
spacial_dim: int, spacial_dim: int,
embed_dim: int, embed_dim: int,
num_heads_channels: int, num_heads_channels: int,
output_dim: int = None, output_dim: Optional[int] = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter( self.positional_embedding = nn.Parameter(
@ -519,10 +520,8 @@ class UNetModel(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks] self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else: else:
if len(num_res_blocks) != len(channel_mult): if len(num_res_blocks) != len(channel_mult):
raise ValueError( msg = "provide num_res_blocks either as an int (globally constant) or as a list/tuple (per-level) with the same length as channel_mult"
"provide num_res_blocks either as an int (globally constant) or " raise ValueError(msg)
"as a list/tuple (per-level) with the same length as channel_mult"
)
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None: if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not

@ -52,7 +52,8 @@ def make_beta_schedule(
** 0.5 ** 0.5
) )
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") msg = f"schedule '{schedule}' unknown."
raise ValueError(msg)
return betas.numpy() return betas.numpy()
@ -80,9 +81,8 @@ def make_ddim_timesteps(
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int) ).astype(int)
else: else:
raise NotImplementedError( msg = f'There is no ddim discretization method called "{ddim_discr_method}"'
f'There is no ddim discretization method called "{ddim_discr_method}"' raise NotImplementedError(msg)
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps # assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -93,7 +93,7 @@ def make_ddim_timesteps(
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta): def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
# select alphas for computing the variance schedule # select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps] alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) alphas_prev = np.asarray([alphacums[0], *alphacums[ddim_timesteps[:-1]].tolist()])
# according to the formula provided in https://arxiv.org/abs/2010.02502 # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt( sigmas = eta * np.sqrt(
@ -151,7 +151,7 @@ def checkpoint(func, inputs, params, flag):
return func(*inputs) return func(*inputs)
class CheckpointFunction(torch.autograd.Function): # noqa class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, run_function, length, *args): def forward(ctx, run_function, length, *args):
ctx.run_function = run_function ctx.run_function = run_function
@ -180,7 +180,7 @@ class CheckpointFunction(torch.autograd.Function): # noqa
del ctx.input_tensors del ctx.input_tensors
del ctx.input_params del ctx.input_params
del output_tensors del output_tensors
return (None, None) + input_grads return (None, None, *input_grads)
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
@ -246,7 +246,7 @@ def normalization(channels):
class GroupNorm32(nn.GroupNorm): class GroupNorm32(nn.GroupNorm):
def forward(self, x): # noqa def forward(self, x):
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
@ -260,7 +260,8 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs) return nn.Conv2d(*args, **kwargs)
if dims == 3: if dims == 3:
return nn.Conv3d(*args, **kwargs) return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") msg = f"unsupported dimensions: {dims}"
raise ValueError(msg)
def linear(*args, **kwargs): def linear(*args, **kwargs):
@ -278,7 +279,8 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs) return nn.AvgPool2d(*args, **kwargs)
if dims == 3: if dims == 3:
return nn.AvgPool3d(*args, **kwargs) return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") msg = f"unsupported dimensions: {dims}"
raise ValueError(msg)
class HybridConditioner(nn.Module): class HybridConditioner(nn.Module):

@ -43,7 +43,7 @@ class ClassEmbedder(nn.Module):
return uc return uc
def disabled_train(self, mode=True): # noqa def disabled_train(self, mode=True):
""" """
For disabling train/eval mode. For disabling train/eval mode.

@ -61,9 +61,8 @@ def load_midas_transform(model_type="dpt_hybrid"):
) )
else: else:
assert ( msg = f"model_type '{model_type}' not implemented, use: --model_type large"
False raise NotImplementedError(msg)
), f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose( transform = Compose(
[ [
@ -133,8 +132,8 @@ def load_model(model_type):
) )
else: else:
print(f"model_type '{model_type}' not implemented, use: --model_type large") msg = f"model_type '{model_type}' not implemented, use: --model_type large"
assert False raise NotImplementedError(msg)
transform = Compose( transform = Compose(
[ [
@ -155,13 +154,13 @@ def load_model(model_type):
return model.eval(), transform return model.eval(), transform
@lru_cache() @lru_cache
def midas_device(): def midas_device():
# mps returns incorrect results ~50% of the time # mps returns incorrect results ~50% of the time
return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device("cuda" if torch.cuda.is_available() else "cpu")
@lru_cache() @lru_cache
def load_midas(model_type="dpt_hybrid"): def load_midas(model_type="dpt_hybrid"):
model = MiDaSInference(model_type) model = MiDaSInference(model_type)
model.to(midas_device()) model.to(midas_device())
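
The `@lru_cache()` decorators above lose their empty parentheses; both spellings memoize a zero-argument function the same way. A standalone sketch with a stand-in helper name (not the project's function):

    from functools import lru_cache

    @lru_cache  # equivalent to @lru_cache() when no arguments are passed
    def best_device() -> str:  # hypothetical stand-in for the cached helpers above
        return "cpu"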

@ -4,7 +4,7 @@ from imaginairy import config
from imaginairy.model_manager import get_cached_url_path from imaginairy.model_manager import get_cached_url_path
class BaseModel(torch.nn.Module): # noqa class BaseModel(torch.nn.Module):
def load(self, path): def load(self, path):
""" """
Load model from file. Load model from file.

@ -56,8 +56,8 @@ def _make_encoder(
[32, 48, 136, 384], features, groups=groups, expand=expand [32, 48, 136, 384], features, groups=groups, expand=expand
) # efficientnet_lite3 ) # efficientnet_lite3
else: else:
print(f"Backbone '{backbone}' not implemented") msg = f"Backbone '{backbone}' not implemented"
assert False raise NotImplementedError(msg)
return pretrained, scratch return pretrained, scratch

@ -135,9 +135,8 @@ class Resize:
# fit height # fit height
scale_width = scale_height scale_width = scale_height
else: else:
raise ValueError( msg = f"resize_method {self.__resize_method} not implemented"
f"resize_method {self.__resize_method} not implemented" raise ValueError(msg)
)
if self.__resize_method == "lower_bound": if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of( new_height = self.constrain_to_multiple_of(
@ -157,7 +156,8 @@ class Resize:
new_height = self.constrain_to_multiple_of(scale_height * height) new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width) new_width = self.constrain_to_multiple_of(scale_width * width)
else: else:
raise ValueError(f"resize_method {self.__resize_method} not implemented") msg = f"resize_method {self.__resize_method} not implemented"
raise ValueError(msg)
return (new_width, new_height) return (new_width, new_height)

@ -22,10 +22,7 @@ class AddReadout(nn.Module):
self.start_index = start_index self.start_index = start_index
def forward(self, x): def forward(self, x):
if self.start_index == 2: readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1) return x[:, self.start_index :] + readout.unsqueeze(1)
@ -118,7 +115,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
def forward_flex(self, x): def forward_flex(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
pos_embed = self._resize_pos_embed( # noqa pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
) )
@ -174,9 +171,8 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1):
ProjectReadout(vit_features, start_index) for out_feat in features ProjectReadout(vit_features, start_index) for out_feat in features
] ]
else: else:
assert ( msg = "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
False raise ValueError(msg)
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper return readout_oper
@ -288,7 +284,7 @@ def _make_vit_b16_backbone(
# We inject this function into the VisionTransformer instances so that # We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source. # we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType( # noqa pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model _resize_pos_embed, pretrained.model
) )
@ -469,7 +465,7 @@ def _make_vit_b_rn50_backbone(
# We inject this function into the VisionTransformer instances so that # We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source. # we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType( # noqa pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model _resize_pos_embed, pretrained.model
) )

@ -94,9 +94,8 @@ def write_pfm(path, image, scale=1):
): # greyscale ): # greyscale
color = False color = False
else: else:
raise ValueError( msg = "Image must have H x W x 3, H x W x 1 or H x W dimensions."
"Image must have H x W x 3, H x W x 1 or H x W dimensions." raise ValueError(msg)
)
file.write("PF\n" if color else b"Pf\n") file.write("PF\n" if color else b"Pf\n")
file.write(b"%d %d\n" % (image.shape[1], image.shape[0])) file.write(b"%d %d\n" % (image.shape[1], image.shape[0]))
@ -144,10 +143,7 @@ def resize_image(img):
height_orig = img.shape[0] height_orig = img.shape[0]
width_orig = img.shape[1] width_orig = img.shape[1]
if width_orig > height_orig: scale = width_orig / 384 if width_orig > height_orig else height_orig / 384
scale = width_orig / 384
else:
scale = height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int) height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int) width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

@ -217,15 +217,18 @@ def outpaint_arg_str_parse(arg_str):
for arg in args: for arg in args:
match = arg_pattern.match(arg) match = arg_pattern.match(arg)
if not match: if not match:
raise ValueError(f"Invalid outpaint argument '{arg}'") msg = f"Invalid outpaint argument '{arg}'"
raise ValueError(msg)
direction, amount = match.groups() direction, amount = match.groups()
direction = direction.lower() direction = direction.lower()
if len(direction) == 1: if len(direction) == 1:
if direction not in valid_direction_chars: if direction not in valid_direction_chars:
raise ValueError(f"Invalid outpaint direction '{direction}'") msg = f"Invalid outpaint direction '{direction}'"
raise ValueError(msg)
direction = valid_direction_chars[direction] direction = valid_direction_chars[direction]
elif direction not in valid_directions: elif direction not in valid_directions:
raise ValueError(f"Invalid outpaint direction '{direction}'") msg = f"Invalid outpaint direction '{direction}'"
raise ValueError(msg)
kwargs[direction] = int(amount) kwargs[direction] = int(amount)
if "all" in kwargs: if "all" in kwargs:

@ -11,13 +11,13 @@ def parse_schedule_str(schedule_str):
pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]") pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]")
match = pattern.match(schedule_str) match = pattern.match(schedule_str)
if not match: if not match:
raise ValueError(f"Invalid kwarg schedule: {schedule_str}") msg = f"Invalid kwarg schedule: {schedule_str}"
raise ValueError(msg)
arg_name = match.group(1).replace("-", "_") arg_name = match.group(1).replace("-", "_")
if not hasattr(ImaginePrompt(), arg_name): if not hasattr(ImaginePrompt(), arg_name):
raise ValueError( msg = f"Invalid kwarg schedule. Not a valid argument name: {arg_name}"
f"Invalid kwarg schedule. Not a valid argument name: {arg_name}" raise ValueError(msg)
)
arg_values = match.group(2) arg_values = match.group(2)
if ":" in arg_values: if ":" in arg_values:
@ -53,7 +53,7 @@ def prompt_mutator(prompt, schedules):
} }
""" """
schedule_length = len(list(schedules.values())[0]) schedule_length = len(next(iter(schedules.values())))
for i in range(schedule_length): for i in range(schedule_length):
new_prompt = copy(prompt) new_prompt = copy(prompt)
for attr_name, schedule in schedules.items(): for attr_name, schedule in schedules.items():
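
`list(x)[0]` becomes `next(iter(x))` here and in several later hunks, so the first element is taken without materializing the whole sequence. A sketch with dummy data:

    schedules = {"prompt_strength": [2, 3, 4]}   # hypothetical schedule dict
    first_schedule = next(iter(schedules.values()))
    assert len(first_schedule) == 3              # same as len(list(schedules.values())[0])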

@ -24,7 +24,8 @@ def square_roi_coordinate(roi, max_width, max_height, best_effort=False):
width = x2 - x1 width = x2 - x1
height = y2 - y1 height = y2 - y1
if not best_effort and width != height: if not best_effort and width != height:
raise RuntimeError(f"ROI is not square: {width}x{height}") msg = f"ROI is not square: {width}x{height}"
raise RuntimeError(msg)
return x1, y1, x2, y2 return x1, y1, x2, y2
@ -96,8 +97,7 @@ def move_roi_into_bounds(roi, max_width, max_height, force=False):
if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height: if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height:
roi_width = x2 - x1 roi_width = x2 - x1
roi_height = y2 - y1 roi_height = y2 - y1
raise RoiNotInBoundsError( msg = f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}" raise RoiNotInBoundsError(msg)
)
return x1, y1, x2, y2 return x1, y1, x2, y2

@ -113,7 +113,7 @@ class EnhancedStableDiffusionSafetyChecker(
return safety_results return safety_results
@lru_cache() @lru_cache
def safety_models(): def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
monkeypatch_safety_cosine_distance() monkeypatch_safety_cosine_distance()
@ -124,7 +124,7 @@ def safety_models():
return safety_feature_extractor, safety_checker return safety_feature_extractor, safety_checker
@lru_cache() @lru_cache
def monkeypatch_safety_cosine_distance(): def monkeypatch_safety_cosine_distance():
orig_cosine_distance = safety_checker_mod.cosine_distance orig_cosine_distance = safety_checker_mod.cosine_distance

@ -51,17 +51,19 @@ class InvalidUrlError(ValueError):
class LazyLoadingImage: class LazyLoadingImage:
"""Image file encoded as base64 string.""" """Image file encoded as base64 string."""
def __init__(self, *, filepath=None, url=None, img: Image = None, b64: str = None): def __init__(
self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None
):
if not filepath and not url and not img and not b64: if not filepath and not url and not img and not b64:
raise ValueError( msg = "You must specify a url or filepath or img or base64 string"
"You must specify a url or filepath or img or base64 string" raise ValueError(msg)
)
if sum([bool(filepath), bool(url), bool(img), bool(b64)]) > 1: if sum([bool(filepath), bool(url), bool(img), bool(b64)]) > 1:
raise ValueError("You cannot multiple input methods") raise ValueError("You cannot multiple input methods")
# validate file exists # validate file exists
if filepath and not os.path.exists(filepath): if filepath and not os.path.exists(filepath):
raise FileNotFoundError(f"File does not exist: {filepath}") msg = f"File does not exist: {filepath}"
raise FileNotFoundError(msg)
# validate url is valid url # validate url is valid url
if url: if url:
@ -73,7 +75,8 @@ class LazyLoadingImage:
except LocationParseError: except LocationParseError:
raise InvalidUrlError(f"Invalid url: {url}") # noqa raise InvalidUrlError(f"Invalid url: {url}") # noqa
if parsed_url.scheme not in {"http", "https"} or not parsed_url.host: if parsed_url.scheme not in {"http", "https"} or not parsed_url.host:
raise InvalidUrlError(f"Invalid url: {url}") msg = f"Invalid url: {url}"
raise InvalidUrlError(msg)
if b64: if b64:
img = self.load_image_from_base64(b64) img = self.load_image_from_base64(b64)
@ -145,16 +148,14 @@ class LazyLoadingImage:
raise ValueError(msg) # noqa raise ValueError(msg) # noqa
if isinstance(value, dict): if isinstance(value, dict):
return cls(**value) return cls(**value)
raise ValueError( msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" raise ValueError(msg)
)
def handle_b64(value: Any) -> "LazyLoadingImage": def handle_b64(value: Any) -> "LazyLoadingImage":
if isinstance(value, str): if isinstance(value, str):
return cls(b64=value) return cls(b64=value)
raise ValueError( msg = "Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string"
"Image value must be either a LazyLoadingImage, PIL.Image.Image or a Base64 string" raise ValueError(msg)
)
return core_schema.json_or_python_schema( return core_schema.json_or_python_schema(
json_schema=core_schema.chain_schema( json_schema=core_schema.chain_schema(
@ -349,15 +350,13 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return "" return ""
if not isinstance(v, str): if not isinstance(v, str):
raise ValueError( msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" raise ValueError(msg) # noqa
)
v = v.lower() v = v.lower()
if v not in valid_tile_modes: if v not in valid_tile_modes:
raise ValueError( msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}" raise ValueError(msg)
)
return v return v
@field_validator("outpaint", mode="after") @field_validator("outpaint", mode="after")
@ -375,7 +374,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return v return v
if not isinstance(v, Tensor): if not isinstance(v, Tensor):
raise ValueError("conditioning must be a torch.Tensor") raise ValueError("conditioning must be a torch.Tensor") # noqa
return v return v
# @field_validator("init_image", "mask_image", mode="after") # @field_validator("init_image", "mask_image", mode="after")
@ -412,13 +411,15 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@field_validator("mask_image") @field_validator("mask_image")
def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo): def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
if v is not None and info.data.get("mask_prompt") is not None: if v is not None and info.data.get("mask_prompt") is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`") msg = "You can only set one of `mask_image` and `mask_prompt`"
raise ValueError(msg)
return v return v
@field_validator("mask_prompt", "mask_image", mode="before") @field_validator("mask_prompt", "mask_image", mode="before")
def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo): def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
if info.data.get("init_image") is None and v: if info.data.get("init_image") is None and v:
raise ValueError("You must set `init_image` if you want to use a mask") msg = "You must set `init_image` if you want to use a mask"
raise ValueError(msg)
return v return v
@field_validator("model", mode="before") @field_validator("model", mode="before")
@ -455,9 +456,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
SamplerName.PLMS, SamplerName.PLMS,
SamplerName.DDIM, SamplerName.DDIM,
): ):
raise ValueError( msg = "PLMS and DDIM samplers are not supported for pix2pix edit model."
"PLMS and DDIM samplers are not supported for pix2pix edit model." raise ValueError(msg)
)
return v return v
@field_validator("steps") @field_validator("steps")
@ -620,7 +620,7 @@ class ImagineResult:
self.is_nsfw = is_nsfw self.is_nsfw = is_nsfw
self.safety_score = safety_score self.safety_score = safety_score
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc) self.created_at = datetime.now(tz=timezone.utc)
self.torch_backend = get_device() self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device()) self.hardware_name = get_hardware_description(get_device())
@ -655,9 +655,8 @@ class ImagineResult:
def save(self, save_path, image_type="generated"): def save(self, save_path, image_type="generated"):
img = self.images.get(image_type, None) img = self.images.get(image_type, None)
if img is None: if img is None:
raise ValueError( msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
f"Image of type {image_type} not stored. Options are: {self.images.keys()}" raise ValueError(msg)
)
img.convert("RGB").save(save_path, exif=self._exif()) img.convert("RGB").save(save_path, exif=self._exif())

@ -192,7 +192,7 @@ def surprise_me_prompts(
width=width, width=width,
height=height, height=height,
seed=seed, seed=seed,
**kwargs, # noqa **kwargs,
) )
) )
else: else:
@ -206,7 +206,7 @@ def surprise_me_prompts(
width=width, width=width,
height=height, height=height,
seed=seed, seed=seed,
**kwargs, # noqa **kwargs,
) )
) )

@ -20,6 +20,8 @@ except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work # Use >= 1.6.0 to make this work
DDPStrategy = None DDPStrategy = None
import contextlib
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
@ -220,10 +222,8 @@ class SetupCallback(Callback):
dst, name = os.path.split(self.logdir) dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name) dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True) os.makedirs(os.path.split(dst)[0], exist_ok=True)
try: with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst) os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback): class ImageLogger(Callback):
@ -342,11 +342,12 @@ class ImageLogger(Callback):
): ):
if not self.disabled and pl_module.global_step > 0: if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val") self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"): if (
if ( hasattr(pl_module, "calibrate_grad_norm")
pl_module.calibrate_grad_norm and batch_idx % 25 == 0 and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
) and batch_idx > 0: and batch_idx > 0
self.log_gradients(trainer, pl_module, batch_idx=batch_idx) ):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback): class CUDACallback(Callback):
@ -356,9 +357,9 @@ class CUDACallback(Callback):
if "cuda" in get_device(): if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index) torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time() # noqa self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module): # noqa def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device(): if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index) torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = ( max_memory = (
@ -394,19 +395,20 @@ def train_diffusion_model(
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
""" """
if DDPStrategy is None: if DDPStrategy is None:
raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model") msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1 batch_size = 1
seed = 23 seed = 23
num_workers = 1 num_workers = 1
num_val_workers = 0 num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now) logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints") ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs") cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed) seed_everything(seed)
model = get_diffusion_model( # noqa model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True weights_location=weights_location, half_mode=False, for_training=True
)._model )._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
@ -501,9 +503,7 @@ def train_diffusion_model(
num_sanity_val_steps=0, num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches, accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(), strategy=DDPStrategy(),
callbacks=[ callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # noqa
],
gpus=1, gpus=1,
default_root_dir=".", default_root_dir=".",
) )
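
Earlier in this file's hunks, the try/except-pass around `os.rename` collapses to `contextlib.suppress`, as ruff's SIM105 rule suggests. A self-contained sketch with hypothetical paths:

    import contextlib
    import os

    # Equivalent to: try: os.rename(...) / except FileNotFoundError: pass
    with contextlib.suppress(FileNotFoundError):
        os.rename("logs/run-old", "logs/child_runs/run-old")  # hypothetical paths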

@ -145,7 +145,7 @@ def create_class_images(class_description, output_folder, num_images=200):
while existing_image_count < num_images: while existing_image_count < num_images:
prompt = ImaginePrompt(class_description, steps=20) prompt = ImaginePrompt(class_description, steps=20)
result = list(imagine([prompt]))[0] result = next(iter(imagine([prompt])))
if result.is_nsfw: if result.is_nsfw:
continue continue
dest = os.path.join( dest = os.path.join(

@ -29,7 +29,7 @@ def prune_model_data(data, only_keep_ema=True):
data.pop("optimizer_states", None) data.pop("optimizer_states", None)
if only_keep_ema: if only_keep_ema:
state_dict = data["state_dict"] state_dict = data["state_dict"]
model_keys = [k for k in state_dict.keys() if k.startswith("model.")] model_keys = [k for k in state_dict if k.startswith("model.")]
for model_key in model_keys: for model_key in model_keys:
ema_key = "model_ema." + model_key[6:].replace(".", "") ema_key = "model_ema." + model_key[6:].replace(".", "")

@ -92,7 +92,8 @@ class SingleConceptDataset(Dataset):
try: try:
image = Image.open(img_path).convert("RGB") image = Image.open(img_path).convert("RGB")
except RuntimeError as e: except RuntimeError as e:
raise RuntimeError(f"Could not read image {img_path}") from e msg = f"Could not read image {img_path}"
raise RuntimeError(msg) from e
image = self.image_transforms(image) image = self.image_transforms(image)
data = {"image": image, "txt": txt} data = {"image": image, "txt": txt}
return data return data

@ -14,7 +14,7 @@ from torch.overrides import handle_torch_function, has_torch_function_variadic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@lru_cache() @lru_cache
def get_device() -> str: def get_device() -> str:
"""Return the best torch backend available.""" """Return the best torch backend available."""
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -26,7 +26,7 @@ def get_device() -> str:
return "cpu" return "cpu"
@lru_cache() @lru_cache
def get_hardware_description(device_type: str) -> str: def get_hardware_description(device_type: str) -> str:
"""Description of the hardware being used.""" """Description of the hardware being used."""
desc = platform.platform() desc = platform.platform()
@ -185,10 +185,9 @@ def check_torch_working():
torch.randn(1, device=get_device()) torch.randn(1, device=get_device())
except RuntimeError as e: except RuntimeError as e:
if "CUDA" in str(e): if "CUDA" in str(e):
raise RuntimeError( msg = "CUDA is not working. Make sure you have a GPU and CUDA installed."
"CUDA is not working. Make sure you have a GPU and CUDA installed." raise RuntimeError(msg) from e
) from e raise
raise e
def frange(start, stop, step): def frange(start, stop, step):
@ -209,7 +208,7 @@ def shrink_list(items, max_size):
new_items = {} new_items = {}
for i, item in enumerate(items): for i, item in enumerate(items):
new_items[int(i / removal_ratio)] = item new_items[int(i / removal_ratio)] = item
return [items[0]] + list(new_items.values()) return [items[0], *list(new_items.values())]
def glob_expand_paths(paths): def glob_expand_paths(paths):

@ -1,3 +1,4 @@
import contextlib
import math import math
import sys import sys
from copy import deepcopy from copy import deepcopy
@ -77,7 +78,7 @@ class DataDistorter:
def __init__(self, data, add_data_values=True): def __init__(self, data, add_data_values=True):
self.data = deepcopy(data) self.data = deepcopy(data)
self.data_map, self.data_unique_values = create_node_map(self.data) self.data_map, self.data_unique_values = create_node_map(self.data)
self.distortion_values = DISTORTED_VALUES + [] self.distortion_values = [*DISTORTED_VALUES]
if add_data_values: if add_data_values:
self.distortion_values += list(self.data_unique_values) self.distortion_values += list(self.data_unique_values)
@ -141,15 +142,13 @@ def create_node_map(data: Union[dict, list, tuple]) -> Tuple[Dict[int, list], se
if isinstance(curr_data, dict): if isinstance(curr_data, dict):
for key, value in curr_data.items(): for key, value in curr_data.items():
_traverse(value, curr_path + [key]) _traverse(value, [*curr_path, key])
elif isinstance(curr_data, (list, tuple)): elif isinstance(curr_data, (list, tuple)):
for idx, item in enumerate(curr_data): for idx, item in enumerate(curr_data):
_traverse(item, curr_path + [idx]) _traverse(item, [*curr_path, idx])
else: else:
try: with contextlib.suppress(TypeError):
node_values.add(curr_data) node_values.add(curr_data)
except TypeError:
pass
_traverse(data, []) _traverse(data, [])
return node_map, node_values return node_map, node_values

@ -203,9 +203,8 @@ class GPUModelCache:
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2) total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0 pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3) return total_ram_gb * pct_to_use * (1024**3)
raise ValueError( msg = f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}" raise ValueError(msg)
)
return self._max_cpu_memory_gb * (1024**3) return self._max_cpu_memory_gb * (1024**3)
@cached_property @cached_property
@ -224,9 +223,8 @@ class GPUModelCache:
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2) total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0 pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3) return total_ram_gb * pct_to_use * (1024**3)
raise ValueError( msg = f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}" raise ValueError(msg)
)
return self._max_gpu_memory_gb * (1024**3) return self._max_gpu_memory_gb * (1024**3)
def _move_to_gpu(self, key, model): def _move_to_gpu(self, key, model):
@ -280,12 +278,12 @@ class GPUModelCache:
import torch import torch
if key not in self: if key not in self:
raise KeyError(f"The key {key} does not exist in the cache") msg = f"The key {key} does not exist in the cache"
raise KeyError(msg)
if key in self.cpu_cache: if key in self.cpu_cache and self.device != torch.device("cpu"):
if self.device != torch.device("cpu"): self.cpu_cache.move_to_end(key)
self.cpu_cache.move_to_end(key) self._move_to_gpu(key, self.cpu_cache[key])
self._move_to_gpu(key, self.cpu_cache[key])
if key in self.gpu_cache: if key in self.gpu_cache:
self.gpu_cache.move_to_end(key) self.gpu_cache.move_to_end(key)
@ -337,7 +335,7 @@ class MemoryManagedModelWrapper:
self._mmmw_kwargs = kwargs self._mmmw_kwargs = kwargs
self._mmmw_namespace = namespace self._mmmw_namespace = namespace
self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items()) self._mmmw_cache_key = (namespace, *args, *tuple(kwargs.items()))
def _mmmw_load_model(self): def _mmmw_load_model(self):
if self._mmmw_cache_key not in self.__class__._mmmw_cache: if self._mmmw_cache_key not in self.__class__._mmmw_cache:
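
The cache lookup above folds a nested `if` into one condition (a SIM-style simplification): a model is promoted only when the key is cached on the CPU and the target device is not the CPU. A tiny sketch using an OrderedDict as a stand-in cache:

    from collections import OrderedDict

    cpu_cache = OrderedDict(model_a="weights")  # hypothetical cached model
    device = "cuda"                             # hypothetical current device
    key = "model_a"

    # One combined condition instead of two nested `if` statements.
    if key in cpu_cache and device != "cpu":
        cpu_cache.move_to_end(key)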

@ -85,7 +85,7 @@ class BatchedBrownianTree:
seed = [seed] seed = [seed]
self.batched = False self.batched = False
self.trees = [ self.trees = [
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed # noqa
] ]
@staticmethod @staticmethod

@ -1,10 +1,6 @@
black black
coverage coverage
isort
ruff ruff
pycln
pylama
pylint
pytest pytest
pytest-randomly pytest-randomly
pytest-sugar pytest-sugar

@ -16,8 +16,6 @@ anyio==3.7.1
# via # via
# fastapi # fastapi
# starlette # starlette
astroid==2.15.8
# via pylint
async-timeout==4.0.3 async-timeout==4.0.3
# via aiohttp # via aiohttp
attrs==23.1.0 attrs==23.1.0
@ -36,7 +34,6 @@ click==8.1.7
# click-help-colors # click-help-colors
# click-shell # click-shell
# imaginAIry (setup.py) # imaginAIry (setup.py)
# typer
# uvicorn # uvicorn
click-help-colors==0.9.2 click-help-colors==0.9.2
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
@ -48,10 +45,8 @@ coverage==7.3.1
# via -r requirements-dev.in # via -r requirements-dev.in
cycler==0.12.0 cycler==0.12.0
# via matplotlib # via matplotlib
diffusers==0.21.3 diffusers==0.21.4
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
dill==0.3.7
# via pylint
einops==0.6.1 einops==0.6.1
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
exceptiongroup==1.1.3 exceptiongroup==1.1.3
@ -71,7 +66,7 @@ filelock==3.12.4
# transformers # transformers
filterpy==1.4.5 filterpy==1.4.5
# via facexlib # via facexlib
fonttools==4.42.1 fonttools==4.43.0
# via matplotlib # via matplotlib
frozenlist==1.4.0 frozenlist==1.4.0
# via # via
@ -104,18 +99,10 @@ importlib-metadata==6.8.0
# via diffusers # via diffusers
iniconfig==2.0.0 iniconfig==2.0.0
# via pytest # via pytest
isort==5.12.0
# via
# -r requirements-dev.in
# pylint
kiwisolver==1.4.5 kiwisolver==1.4.5
# via matplotlib # via matplotlib
kornia==0.7.0 kornia==0.7.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
lazy-object-proxy==1.9.0
# via astroid
libcst==1.0.1
# via pycln
lightning-utilities==0.9.0 lightning-utilities==0.9.0
# via # via
# pytorch-lightning # pytorch-lightning
@ -126,18 +113,12 @@ matplotlib==3.7.3
# via # via
# -c tests/constraints.txt # -c tests/constraints.txt
# filterpy # filterpy
mccabe==0.7.0
# via
# pylama
# pylint
multidict==6.0.4 multidict==6.0.4
# via # via
# aiohttp # aiohttp
# yarl # yarl
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via # via black
# black
# typing-inspect
numba==0.58.0 numba==0.58.0
# via facexlib # via facexlib
numpy==1.24.4 numpy==1.24.4
@ -178,9 +159,7 @@ packaging==23.1
# pytorch-lightning # pytorch-lightning
# transformers # transformers
pathspec==0.11.2 pathspec==0.11.2
# via # via black
# black
# pycln
pillow==10.0.1 pillow==10.0.1
# via # via
# diffusers # diffusers
@ -190,9 +169,7 @@ pillow==10.0.1
# matplotlib # matplotlib
# torchvision # torchvision
platformdirs==3.10.0 platformdirs==3.10.0
# via # via black
# black
# pylint
pluggy==1.3.0 pluggy==1.3.0
# via pytest # via pytest
protobuf==3.20.3 protobuf==3.20.3
@ -201,24 +178,12 @@ protobuf==3.20.3
# open-clip-torch # open-clip-torch
psutil==5.9.5 psutil==5.9.5
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
pycln==2.2.2
# via -r requirements-dev.in
pycodestyle==2.11.0
# via pylama
pydantic==2.4.2 pydantic==2.4.2
# via # via
# fastapi # fastapi
# imaginAIry (setup.py) # imaginAIry (setup.py)
pydantic-core==2.10.1 pydantic-core==2.10.1
# via pydantic # via pydantic
pydocstyle==6.3.0
# via pylama
pyflakes==3.1.0
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.6
# via -r requirements-dev.in
pyparsing==3.1.1 pyparsing==3.1.1
# via matplotlib # via matplotlib
pytest==7.4.2 pytest==7.4.2
@ -237,9 +202,7 @@ pytorch-lightning==1.9.5
pyyaml==6.0.1 pyyaml==6.0.1
# via # via
# huggingface-hub # huggingface-hub
# libcst
# omegaconf # omegaconf
# pycln
# pytorch-lightning # pytorch-lightning
# responses # responses
# timm # timm
@ -280,8 +243,6 @@ six==1.16.0
# via python-dateutil # via python-dateutil
sniffio==1.3.0 sniffio==1.3.0
# via anyio # via anyio
snowballstemmer==2.2.0
# via pydocstyle
starlette==0.27.0 starlette==0.27.0
# via fastapi # via fastapi
termcolor==2.3.0 termcolor==2.3.0
@ -295,12 +256,7 @@ tokenizers==0.13.3
tomli==2.0.1 tomli==2.0.1
# via # via
# black # black
# pylint
# pytest # pytest
tomlkit==0.12.1
# via
# pycln
# pylint
torch==1.13.1 torch==1.13.1
# via # via
# facexlib # facexlib
@ -335,28 +291,20 @@ tqdm==4.66.1
# transformers # transformers
transformers==4.33.3 transformers==4.33.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
typer==0.9.0
# via pycln
types-pyyaml==6.0.12.12 types-pyyaml==6.0.12.12
# via responses # via responses
typing-extensions==4.8.0 typing-extensions==4.8.0
# via # via
# astroid
# black # black
# fastapi # fastapi
# huggingface-hub # huggingface-hub
# libcst
# lightning-utilities # lightning-utilities
# pydantic # pydantic
# pydantic-core # pydantic-core
# pytorch-lightning # pytorch-lightning
# torch # torch
# torchvision # torchvision
# typer
# typing-inspect
# uvicorn # uvicorn
typing-inspect==0.9.0
# via libcst
urllib3==2.0.5 urllib3==2.0.5
# via # via
# requests # requests
@ -367,8 +315,6 @@ wcwidth==0.2.7
# via ftfy # via ftfy
wheel==0.41.2 wheel==0.41.2
# via -r requirements-dev.in # via -r requirements-dev.in
wrapt==1.15.0
# via astroid
yarl==1.9.2 yarl==1.9.2
# via aiohttp # via aiohttp
zipp==3.17.0 zipp==3.17.0

@ -52,8 +52,8 @@ def main():
time.sleep(1) time.sleep(1)
controlnet_statedict = torch.load(controlnet_path, map_location="cpu") controlnet_statedict = torch.load(controlnet_path, map_location="cpu")
print("\n\nComparing reconstructed controlnet with original") print("\n\nComparing reconstructed controlnet with original")
for k in controlnet_statedict.keys(): for k in controlnet_statedict:
if k not in reconstituted_controlnet_statedict.keys(): if k not in reconstituted_controlnet_statedict:
print(f"Key {k} not in reconstituted") print(f"Key {k} not in reconstituted")
elif ( elif (
controlnet_statedict[k].shape controlnet_statedict[k].shape

@ -54,7 +54,7 @@ def make_txts():
with open(src_json, encoding="utf-8") as f: with open(src_json, encoding="utf-8") as f:
prompts = json.load(f) prompts = json.load(f)
categories = [] categories = []
for c in prompts.keys(): for c in prompts:
if any(c.startswith(p) for p in excluded_prefixes): if any(c.startswith(p) for p in excluded_prefixes):
continue continue
categories.append(c) categories.append(c)

@ -19,7 +19,7 @@ else:
entry_points = None entry_points = None
@lru_cache() @lru_cache
def get_git_revision_hash() -> str: def get_git_revision_hash() -> str:
try: try:
return ( return (

@ -1,3 +1,4 @@
import contextlib
import logging import logging
import os import os
import sys import sys
@ -33,16 +34,14 @@ elif get_device() == "cpu":
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def pre_setup(): def _pre_setup():
api.IMAGINAIRY_SAFETY_MODE = "disabled" api.IMAGINAIRY_SAFETY_MODE = "disabled"
suppress_annoying_logs_and_warnings() suppress_annoying_logs_and_warnings()
test_output_folder = f"{TESTS_FOLDER}/test_output" test_output_folder = f"{TESTS_FOLDER}/test_output"
# delete the testoutput folder and recreate it # delete the testoutput folder and recreate it
try: with contextlib.suppress(FileNotFoundError):
rmtree(test_output_folder) rmtree(test_output_folder)
except FileNotFoundError:
pass
os.makedirs(test_output_folder, exist_ok=True) os.makedirs(test_output_folder, exist_ok=True)
orig_urlopen = HTTPConnectionPool.urlopen orig_urlopen = HTTPConnectionPool.urlopen
@ -73,7 +72,7 @@ def pre_setup():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_get_device(): def _reset_get_device():
get_device.cache_clear() get_device.cache_clear()
@ -94,7 +93,7 @@ def sampler_type(request):
return request.param return request.param
@pytest.fixture @pytest.fixture()
def mocked_responses(): def mocked_responses():
with responses.RequestsMock() as rsps: with responses.RequestsMock() as rsps:
yield rsps yield rsps

@ -12,7 +12,7 @@ blur_params = [
] ]
@pytest.mark.parametrize("img_path,expected", blur_params) @pytest.mark.parametrize(("img_path", "expected"), blur_params)
def test_calculate_blurriness_level(img_path, expected): def test_calculate_blurriness_level(img_path, expected):
img = Image.open(img_path) img = Image.open(img_path)

@ -15,7 +15,7 @@ def control_img_to_pillow_img(img_t):
control_mode_params = list(CONTROL_MODES.items()) control_mode_params = list(CONTROL_MODES.items())
@pytest.mark.parametrize("control_name,control_func", control_mode_params) @pytest.mark.parametrize(("control_name", "control_func"), control_mode_params)
def test_control_images(filename_base_for_outputs, control_func, control_name): def test_control_images(filename_base_for_outputs, control_func, control_name):
seed_everything(42) seed_everything(42)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png") img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")

@ -26,7 +26,7 @@ strat_combos = [
@pytest.mark.skipif(True, reason="Run manually as needed. Uses too much memory.") @pytest.mark.skipif(True, reason="Run manually as needed. Uses too much memory.")
@pytest.mark.parametrize("encode_strat,decode_strat", strat_combos) @pytest.mark.parametrize(("encode_strat", "decode_strat"), strat_combos)
def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat): def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat):
""" """
Test that encoding and decoding works. Test that encoding and decoding works.

@ -0,0 +1,14 @@
extend-ignore = ["E501", "G004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100"]
extend-exclude = ["imaginairy/vendored", "downloads", "other"]
extend-select = [
"I", "E", "W", "UP", "ASYNC", "BLE", "A001", "A002",
"C4", "DTZ", "T10", "EM", "ISC", "ICN", "G", "PIE", "PT",
"Q", "SIM", "TID", "TCH", "PLC", "PLE", "TRY", "RUF"
]
[isort]
combine-as-imports = true
[flake8-errmsg]
max-string-length = 50
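
The `PT` entry in `extend-select` above is presumably what drives the `pytest.mark.parametrize` rewrites in the surrounding test hunks: argument names become a tuple instead of one comma-separated string. A sketch with invented test data:

    import pytest

    @pytest.mark.parametrize(("img_path", "expected"), [("a.jpg", 1), ("b.jpg", 2)])
    def test_example(img_path, expected):  # hypothetical test and cases
        assert img_path.endswith(".jpg")
        assert expected > 0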

@ -174,10 +174,7 @@ def test_img_to_img_fruit_2_gold(
filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg") filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
) )
target_steps = 25 target_steps = 25
if init_strength >= 1: needed_steps = 25 if init_strength >= 1 else int(target_steps / (1 - init_strength))
needed_steps = 25
else:
needed_steps = int(target_steps / (1 - init_strength))
prompt = ImaginePrompt( prompt = ImaginePrompt(
"a white bowl filled with gold coins", "a white bowl filled with gold coins",
prompt_strength=12, prompt_strength=12,

@ -126,7 +126,7 @@ boolean_mask_test_cases = [
] ]
@pytest.mark.parametrize("mask_text,expected", boolean_mask_test_cases) @pytest.mark.parametrize(("mask_text", "expected"), boolean_mask_test_cases)
def test_clip_mask_parser(mask_text, expected): def test_clip_mask_parser(mask_text, expected):
parsed = MASK_PROMPT.parseString(mask_text)[0][0] parsed = MASK_PROMPT.parseString(mask_text)[0][0]
assert str(parsed) == expected assert str(parsed) == expected

@ -46,7 +46,7 @@ cases = [
] ]
@pytest.mark.parametrize("img_ratio, tile_size, overlap_pct", cases) @pytest.mark.parametrize(("img_ratio", "tile_size", "overlap_pct"), cases)
def test_feather_tile_simple(img_ratio, tile_size, overlap_pct): def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
img = pillow_img_to_torch_image( img = pillow_img_to_torch_image(
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg") LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")

@ -23,7 +23,7 @@ def test_outpainting_outpaint(filename_base_for_outputs):
steps=20, steps=20,
seed=542906833, seed=542906833,
) )
result = list(imagine([prompt]))[0] result = next(iter(imagine([prompt])))
img_path = f"{filename_base_for_outputs}.png" img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=17000) assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=17000)
@ -37,6 +37,6 @@ outpaint_test_params = [
] ]
@pytest.mark.parametrize("arg_str, expected_kwargs", outpaint_test_params) @pytest.mark.parametrize(("arg_str", "expected_kwargs"), outpaint_test_params)
def test_outpaint_parse_kwargs(arg_str, expected_kwargs): def test_outpaint_parse_kwargs(arg_str, expected_kwargs):
assert outpaint_arg_str_parse(arg_str) == expected_kwargs assert outpaint_arg_str_parse(arg_str) == expected_kwargs

@ -5,7 +5,7 @@ from imaginairy.utils import frange
@pytest.mark.parametrize( @pytest.mark.parametrize(
"schedule_str,expected", ("schedule_str", "expected"),
[ [
("prompt_strength[2:40:1]", ("prompt_strength", list(range(2, 40)))), ("prompt_strength[2:40:1]", ("prompt_strength", list(range(2, 40)))),
("prompt_strength[2:40:0.5]", ("prompt_strength", list(frange(2, 40, 0.5)))), ("prompt_strength[2:40:0.5]", ("prompt_strength", list(frange(2, 40, 0.5)))),

@ -26,7 +26,7 @@ def _red_url(mocked_responses):
status=200, status=200,
content_type="image/png", content_type="image/png",
) )
yield url return url
@pytest.fixture(name="red_path") @pytest.fixture(name="red_path")

@ -107,7 +107,7 @@ def test_cache_ordering():
) )
cache.set("key-0", create_model_of_n_bytes(4_000_000)) cache.set("key-0", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa assert list(cache.cpu_cache.keys()) == []
assert list(cache.gpu_cache.keys()) == ["key-0"] assert list(cache.gpu_cache.keys()) == ["key-0"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == ( assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0, 0,
@ -115,7 +115,7 @@ def test_cache_ordering():
) )
cache.set("key-1", create_model_of_n_bytes(4_000_000)) cache.set("key-1", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa assert list(cache.cpu_cache.keys()) == []
assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"] assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == ( assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0, 0,

@ -68,7 +68,7 @@ def test_instantiate_from_config():
"params": {"year": 2002, "month": 10, "day": 1}, "params": {"year": 2002, "month": 10, "day": 1},
} }
o = instantiate_from_config(config) o = instantiate_from_config(config)
assert o == datetime(2002, 10, 1) assert o == datetime(2002, 10, 1) # noqa: DTZ001
config = "__is_first_stage__" config = "__is_first_stage__"
assert instantiate_from_config(config) is None assert instantiate_from_config(config) is None

@ -4,25 +4,3 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
filterwarnings = filterwarnings =
ignore::DeprecationWarning ignore::DeprecationWarning
ignore::UserWarning ignore::UserWarning
[pylama]
format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py,.eggs/*
linters = pylint,pycodestyle,pyflakes,mypy
ignore =
Z999,C0103,C0201,C0301,C0302,C0114,C0115,C0116,C0415,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,E1121,E1131,E1133,E1135,E1136,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W0632,W1203
[pylama:tests/*]
ignore = C0104,C0114,C0116,D103,W0143,W0613
[pylama:*/__init__.py]
ignore = D104
[pylama:pylint]
generated_members=torch.*
extension-pkg-whitelist=pydantic
