feature: replaces black formatter with ruff formatter

jaydrennan 2023-12-27 14:53:05 -08:00
parent a2c38b3ec0
commit 7eef3bf628
15 changed files with 53 additions and 80 deletions

View File

@@ -29,29 +29,14 @@ jobs:
- name: Install Ruff
if: steps.cache.outputs.cache-hit != 'true'
run: grep -E 'ruff==' requirements-dev.txt | xargs pip install
- name: Format
run: |
echo "::add-matcher::.github/pylama_matcher.json"
ruff format --config tests/ruff.toml . --check
- name: Lint
run: |
echo "::add-matcher::.github/pylama_matcher.json"
ruff --config tests/ruff.toml .
autoformat:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4.5.0
with:
python-version: "3.10"
- name: Cache dependencies
uses: actions/cache@v3.2.4
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-autoformat
- name: Install Black
if: steps.cache.outputs.cache-hit != 'true'
run: grep -E 'black==' requirements-dev.txt | xargs pip install
- name: Lint
run: |
black --diff --fast .
test-gpu:
runs-on: nvidia-4090
steps:
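
Note: `ruff format --check` exits non-zero when any file would be reformatted, mirroring the old `black --diff --fast .` gate, so the separate autoformat job is dropped. A minimal sketch of the black-compatible rewrite the formatter enforces (stand-in names, not from this repo):

def add(a, b):
    return a + b

# No magic trailing comma and the call fits on one line, so `ruff format`
# collapses this to: result = add(1, 2)
result = add(
    1, 2
)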

View File

@@ -29,7 +29,7 @@ init: require_pyenv ## Setup a dev environment for local development.
af: autoformat ## Alias for `autoformat`
autoformat: ## Run the autoformatter.
@-ruff check --config tests/ruff.toml . --fix-only
@black .
@ruff format --config tests/ruff.toml .
test: ## Run the tests.
@pytest
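
The target now runs the auto-fixable lint rules first and the formatter second, so lint fixes are themselves formatted. A hedged Python equivalent of the same two-step flow, for programmatic use (the wrapper function is an illustration, not part of this commit):

import subprocess

def autoformat(path="."):
    # Apply auto-fixable lint rules; errors tolerated, mirroring the
    # make target's "-" prefix.
    subprocess.run(["ruff", "check", "--config", "tests/ruff.toml", path, "--fix-only"])
    # Then format everything with the same config.
    subprocess.run(["ruff", "format", "--config", "tests/ruff.toml", path], check=True)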

View File

@@ -102,9 +102,7 @@ class ControlNet(nn.Module):
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if isinstance(context_dim, ListConfig):

View File

@@ -714,14 +714,17 @@ class DDPM(pl.LightningModule):
def _TileModeConv2DConvForward(
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
self,
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor, # noqa
):
if self.padding_modeX == self.padding_modeY:
self.padding_mode = self.padding_modeX
return self._orig_conv_forward(input, weight, bias)
return self._orig_conv_forward(input_tensor, weight, bias)
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
del input
w1 = F.pad(input_tensor, self.paddingX, mode=self.padding_modeX)
del input_tensor
w2 = F.pad(w1, self.paddingY, mode=self.padding_modeY)
del w1
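
Beyond reflowing the signature, the parameter is renamed from `input` to `input_tensor`, which stops shadowing the `input` builtin (ruff's A002). A self-contained sketch of the padding flow this method implements, assuming stride-1 convs (helper name and default modes are illustrative):

import torch
import torch.nn.functional as F

def pad_then_conv(conv, input_tensor, mode_x="circular", mode_y="constant"):
    # Pad width and height with independent modes, then convolve without
    # further padding, since the input is already padded.
    px, py = conv.padding[1], conv.padding[0]
    w1 = F.pad(input_tensor, (px, px, 0, 0), mode=mode_x)
    w2 = F.pad(w1, (0, 0, py, py), mode=mode_y)
    return F.conv2d(w2, conv.weight, conv.bias, stride=conv.stride)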

View File

@@ -494,9 +494,7 @@ class UNetModel(nn.Module):
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if isinstance(context_dim, ListConfig):
context_dim = list(context_dim)

View File

@@ -32,7 +32,7 @@ class DPT(BaseModel):
readout="project",
channels_last=False,
use_bn=False,
**kwargs
**kwargs,
):
super().__init__()

View File

@@ -32,14 +32,17 @@ TileModeType = Literal["", "x", "y", "xy"]
def _tile_mode_conv2d_conv_forward(
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
self,
tensor_input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor, # noqa
):
if self.padding_mode_x == self.padding_mode_y:
self.padding_mode = self.padding_mode_x
return self._orig_conv_forward(input, weight, bias)
return self._orig_conv_forward(tensor_input, weight, bias)
w1 = F.pad(input, self.padding_x, mode=self.padding_modeX)
del input
w1 = F.pad(tensor_input, self.padding_x, mode=self.padding_modeX)
del tensor_input
w2 = F.pad(w1, self.padding_y, mode=self.padding_modeY)
del w1
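
Same rename pattern here (`input` becomes `tensor_input`). For intuition, circular padding is what makes the "x" tile mode seamless; a small runnable illustration:

import torch
import torch.nn.functional as F

t = torch.arange(4.0).reshape(1, 1, 1, 4)         # [[0, 1, 2, 3]]
padded = F.pad(t, (1, 1, 0, 0), mode="circular")  # wraps to [[3, 0, 1, 2, 3, 0]]
print(padded)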

View File

@@ -625,7 +625,7 @@ class BasicTransformerBlock(nn.Module):
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
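
The added trailing comma matters to both formatters: black and ruff treat it as a "magic trailing comma" that keeps a collection expanded, one entry per line. A toy illustration with stand-in values:

# Trailing comma present: ruff format keeps one entry per line.
expanded = {
    "softmax": "vanilla",
}
# No trailing comma and it fits: collapsed onto a single line.
collapsed = {"softmax": "vanilla"}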

View File

@@ -52,8 +52,9 @@ class LPIPS(nn.Module):
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
outs1[kk]
feats0[kk], feats1[kk] = (
normalize_tensor(outs0[kk]),
normalize_tensor(outs1[kk]),
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
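
For context, `normalize_tensor` in typical LPIPS implementations L2-normalizes features along the channel dimension before the squared difference is taken; a minimal sketch (the eps value is an assumption):

import torch

def normalize_tensor(feat, eps=1e-10):
    # Unit-normalize along the channel axis so distances compare the
    # direction, not the magnitude, of the network activations.
    norm = torch.sqrt(torch.sum(feat**2, dim=1, keepdim=True))
    return feat / (norm + eps)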

View File

@@ -27,8 +27,7 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
return sigma_to, 0.0
sigma_up = torch.minimum(
sigma_to,
eta
* (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
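
The reflowed expression is unchanged mathematically: with eta=1, sigma_up and sigma_down split sigma_to so that sigma_down**2 + sigma_up**2 == sigma_to**2, i.e. the noise added back equals the variance given up. A quick numeric check:

sigma_from, sigma_to, eta = 2.0, 1.0, 1.0
sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
print(sigma_down**2 + sigma_up**2)  # 1.0 == sigma_to**2, up to float rounding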

View File

@@ -468,8 +468,8 @@ class VideoUNet(nn.Module):
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
):
assert (y is not None) == (
self.num_classes is not None
assert (
(y is not None) == (self.num_classes is not None)
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
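
The rewrapped condition is the usual "if and only if" idiom: comparing two booleans with == asserts that y is provided exactly when the model is class-conditional. A minimal illustration:

def check_conditioning(y, num_classes):
    # Both None or both set passes; a mismatch raises AssertionError.
    assert (y is not None) == (num_classes is not None)

check_conditioning(None, None)   # ok
check_conditioning("label", 10)  # ok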

View File

@@ -650,7 +650,8 @@ def open_weights(filepath, device=None):
with safe_open(path=filepath, framework="pytorch", device=device) as tensors:
state_dict = {
key: tensors.get_tensor(key) for key in tensors.keys() # noqa
key: tensors.get_tensor(key)
for key in tensors.keys() # noqa
}
else:
import torch
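
The reflowed comprehension materializes every tensor in the safetensors file into an ordinary dict. A standalone sketch of that branch (the function name and default device are illustrative):

from safetensors import safe_open

def load_safetensors_state_dict(filepath, device="cpu"):
    # Lazily opens the file, then copies each named tensor into a
    # state-dict-style mapping.
    with safe_open(filepath, framework="pytorch", device=device) as tensors:
        return {key: tensors.get_tensor(key) for key in tensors.keys()}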

View File

@@ -1,4 +1,3 @@
black
coverage
httpx
mypy

View File

@@ -5,26 +5,21 @@
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
aiohttp==3.9.1
# via
# black
# fsspec
# via fsspec
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via omegaconf
anyio==3.7.1
anyio==4.2.0
# via
# fastapi
# httpx
# starlette
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
black==23.12.0
# via -r requirements-dev.in
certifi==2023.11.17
# via
# httpcore
@@ -34,7 +29,6 @@ charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
# black
# click-help-colors
# click-shell
# imaginAIry (setup.py)
@@ -45,11 +39,11 @@ click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.2.0
# via matplotlib
coverage==7.3.3
coverage==7.4.0
# via -r requirements-dev.in
cycler==0.12.1
# via matplotlib
diffusers==0.24.0
diffusers==0.25.0
# via imaginAIry (setup.py)
einops==0.7.0
# via imaginAIry (setup.py)
@@ -61,7 +55,7 @@ facexlib==0.3.0
# via imaginAIry (setup.py)
fairscale==0.4.13
# via imaginAIry (setup.py)
fastapi==0.105.0
fastapi==0.108.0
# via imaginAIry (setup.py)
filelock==3.13.1
# via
@@ -71,7 +65,7 @@ filelock==3.13.1
# transformers
filterpy==1.4.5
# via facexlib
fonttools==4.46.0
fonttools==4.47.0
# via matplotlib
frozenlist==1.4.1
# via
@@ -92,9 +86,9 @@ h11==0.14.0
# uvicorn
httpcore==1.0.2
# via httpx
httpx==0.25.2
httpx==0.26.0
# via -r requirements-dev.in
huggingface-hub==0.19.4
huggingface-hub==0.20.1
# via
# diffusers
# open-clip-torch
@@ -109,7 +103,7 @@ idna==3.6
# yarl
imageio==2.33.1
# via imaginAIry (setup.py)
importlib-metadata==7.0.0
importlib-metadata==7.0.1
# via diffusers
iniconfig==2.0.0
# via pytest
@@ -119,7 +113,7 @@ jinja2==3.1.2
# via torch
kiwisolver==1.4.5
# via matplotlib
kornia==0.7.0
kornia==0.7.1
# via imaginAIry (setup.py)
lightning-utilities==0.10.0
# via
@@ -139,12 +133,10 @@ multidict==6.0.4
# via
# aiohttp
# yarl
mypy==1.7.1
mypy==1.8.0
# via -r requirements-dev.in
mypy-extensions==1.0.0
# via
# black
# mypy
# via mypy
networkx==3.2.1
# via torch
numba==0.58.1
@@ -179,7 +171,6 @@ opencv-python==4.8.1.78
# imaginAIry (setup.py)
packaging==23.2
# via
# black
# huggingface-hub
# kornia
# lightning-utilities
@@ -189,8 +180,6 @@ packaging==23.2
# pytorch-lightning
# torchmetrics
# transformers
pathspec==0.12.1
# via black
pillow==10.1.0
# via
# diffusers
@@ -200,21 +189,19 @@ pillow==10.1.0
# matplotlib
# refiners
# torchvision
platformdirs==4.1.0
# via black
pluggy==1.3.0
# via pytest
protobuf==4.25.1
# via
# imaginAIry (setup.py)
# open-clip-torch
psutil==5.9.6
psutil==5.9.7
# via imaginAIry (setup.py)
pydantic==2.5.2
pydantic==2.5.3
# via
# fastapi
# imaginAIry (setup.py)
pydantic-core==2.14.5
pydantic-core==2.14.6
# via pydantic
pyparsing==3.1.1
# via matplotlib
@@ -244,7 +231,7 @@ pyyaml==6.0.1
# transformers
refiners==0.2.0
# via imaginAIry (setup.py)
regex==2023.10.3
regex==2023.12.25
# via
# diffusers
# open-clip-torch
@@ -260,7 +247,7 @@ requests==2.31.0
# transformers
responses==0.24.1
# via -r requirements-dev.in
ruff==0.1.8
ruff==0.1.9
# via -r requirements-dev.in
safetensors==0.3.3
# via
@@ -283,7 +270,7 @@ sniffio==1.3.0
# via
# anyio
# httpx
starlette==0.27.0
starlette==0.32.0.post1
# via fastapi
sympy==1.12
# via torch
@@ -297,7 +284,6 @@ tokenizers==0.15.0
# via transformers
tomli==2.0.1
# via
# black
# mypy
# pytest
torch==2.1.2
@@ -333,7 +319,7 @@ tqdm==4.66.1
# open-clip-torch
# pytorch-lightning
# transformers
transformers==4.36.1
transformers==4.36.2
# via imaginAIry (setup.py)
typeguard==2.13.3
# via jaxtyping
@@ -347,7 +333,7 @@ types-tqdm==4.66.0.5
# via -r requirements-dev.in
typing-extensions==4.9.0
# via
# black
# anyio
# fastapi
# huggingface-hub
# jaxtyping
@@ -363,7 +349,7 @@ urllib3==2.1.0
# requests
# responses
# types-requests
uvicorn==0.24.0.post1
uvicorn==0.25.0
# via imaginAIry (setup.py)
wcwidth==0.2.12
# via ftfy

View File

@@ -1,4 +1,4 @@
extend-ignore = ["E501", "G004", "PT004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100"]
extend-ignore = ["E501", "G004", "PT004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100", "ISC001"]
extend-exclude = ["imaginairy/vendored", "downloads", "other"]
extend-select = [
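
ISC001 (single-line implicit string concatenation) joins the ignore list because ruff warns that this rule can conflict with the formatter's own output. What the rule would otherwise flag, shown with a made-up string:

# ISC001 flags implicit concatenation on one line like this; ignoring it
# stops the linter fighting code the formatter produces.
banner = "imagin" "AIry"  # equivalent to "imaginAIry"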