mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
feature: replaces black formatter with ruff formatter
This commit is contained in:
parent
a2c38b3ec0
commit
7eef3bf628
23
.github/workflows/ci.yaml
vendored
23
.github/workflows/ci.yaml
vendored
@ -29,29 +29,14 @@ jobs:
|
||||
- name: Install Ruff
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: grep -E 'ruff==' requirements-dev.txt | xargs pip install
|
||||
- name: Format
|
||||
run: |
|
||||
echo "::add-matcher::.github/pylama_matcher.json"
|
||||
ruff format --config tests/ruff.toml . --check
|
||||
- name: Lint
|
||||
run: |
|
||||
echo "::add-matcher::.github/pylama_matcher.json"
|
||||
ruff --config tests/ruff.toml .
|
||||
autoformat:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4.5.0
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v3.2.4
|
||||
id: cache
|
||||
with:
|
||||
path: ${{ env.pythonLocation }}
|
||||
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements-dev.txt') }}-autoformat
|
||||
- name: Install Black
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: grep -E 'black==' requirements-dev.txt | xargs pip install
|
||||
- name: Lint
|
||||
run: |
|
||||
black --diff --fast .
|
||||
test-gpu:
|
||||
runs-on: nvidia-4090
|
||||
steps:
|
||||
|
2
Makefile
2
Makefile
@ -29,7 +29,7 @@ init: require_pyenv ## Setup a dev environment for local development.
|
||||
af: autoformat ## Alias for `autoformat`
|
||||
autoformat: ## Run the autoformatter.
|
||||
@-ruff check --config tests/ruff.toml . --fix-only
|
||||
@black .
|
||||
@ruff format --config tests/ruff.toml .
|
||||
|
||||
test: ## Run the tests.
|
||||
@pytest
|
||||
|
@ -102,9 +102,7 @@ class ControlNet(nn.Module):
|
||||
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
||||
|
||||
if context_dim is not None:
|
||||
assert (
|
||||
use_spatial_transformer
|
||||
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
from omegaconf.listconfig import ListConfig
|
||||
|
||||
if isinstance(context_dim, ListConfig):
|
||||
|
@ -714,14 +714,17 @@ class DDPM(pl.LightningModule):
|
||||
|
||||
|
||||
def _TileModeConv2DConvForward(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
|
||||
self,
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor, # noqa
|
||||
):
|
||||
if self.padding_modeX == self.padding_modeY:
|
||||
self.padding_mode = self.padding_modeX
|
||||
return self._orig_conv_forward(input, weight, bias)
|
||||
return self._orig_conv_forward(input_tensor, weight, bias)
|
||||
|
||||
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
|
||||
del input
|
||||
w1 = F.pad(input_tensor, self.paddingX, mode=self.padding_modeX)
|
||||
del input_tensor
|
||||
|
||||
w2 = F.pad(w1, self.paddingY, mode=self.padding_modeY)
|
||||
del w1
|
||||
|
@ -494,9 +494,7 @@ class UNetModel(nn.Module):
|
||||
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
||||
|
||||
if context_dim is not None:
|
||||
assert (
|
||||
use_spatial_transformer
|
||||
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
|
||||
if isinstance(context_dim, ListConfig):
|
||||
context_dim = list(context_dim)
|
||||
|
@ -32,7 +32,7 @@ class DPT(BaseModel):
|
||||
readout="project",
|
||||
channels_last=False,
|
||||
use_bn=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -32,14 +32,17 @@ TileModeType = Literal["", "x", "y", "xy"]
|
||||
|
||||
|
||||
def _tile_mode_conv2d_conv_forward(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
|
||||
self,
|
||||
tensor_input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor, # noqa
|
||||
):
|
||||
if self.padding_mode_x == self.padding_mode_y:
|
||||
self.padding_mode = self.padding_mode_x
|
||||
return self._orig_conv_forward(input, weight, bias)
|
||||
return self._orig_conv_forward(tensor_input, weight, bias)
|
||||
|
||||
w1 = F.pad(input, self.padding_x, mode=self.padding_modeX)
|
||||
del input
|
||||
w1 = F.pad(tensor_input, self.padding_x, mode=self.padding_modeX)
|
||||
del tensor_input
|
||||
|
||||
w2 = F.pad(w1, self.padding_y, mode=self.padding_modeY)
|
||||
del w1
|
||||
|
@ -625,7 +625,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
||||
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
|
||||
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
||||
}
|
||||
|
||||
|
@ -52,8 +52,9 @@ class LPIPS(nn.Module):
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
||||
for kk in range(len(self.chns)):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
||||
outs1[kk]
|
||||
feats0[kk], feats1[kk] = (
|
||||
normalize_tensor(outs0[kk]),
|
||||
normalize_tensor(outs1[kk]),
|
||||
)
|
||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||
|
||||
|
@ -27,8 +27,7 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
|
||||
return sigma_to, 0.0
|
||||
sigma_up = torch.minimum(
|
||||
sigma_to,
|
||||
eta
|
||||
* (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
|
||||
eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
|
||||
)
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
@ -468,8 +468,8 @@ class VideoUNet(nn.Module):
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
):
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
assert (
|
||||
(y is not None) == (self.num_classes is not None)
|
||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
|
@ -650,7 +650,8 @@ def open_weights(filepath, device=None):
|
||||
|
||||
with safe_open(path=filepath, framework="pytorch", device=device) as tensors:
|
||||
state_dict = {
|
||||
key: tensors.get_tensor(key) for key in tensors.keys() # noqa
|
||||
key: tensors.get_tensor(key)
|
||||
for key in tensors.keys() # noqa
|
||||
}
|
||||
else:
|
||||
import torch
|
||||
|
@ -1,4 +1,3 @@
|
||||
black
|
||||
coverage
|
||||
httpx
|
||||
mypy
|
||||
|
@ -5,26 +5,21 @@
|
||||
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
|
||||
#
|
||||
aiohttp==3.9.1
|
||||
# via
|
||||
# black
|
||||
# fsspec
|
||||
# via fsspec
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via omegaconf
|
||||
anyio==3.7.1
|
||||
anyio==4.2.0
|
||||
# via
|
||||
# fastapi
|
||||
# httpx
|
||||
# starlette
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
attrs==23.1.0
|
||||
# via aiohttp
|
||||
black==23.12.0
|
||||
# via -r requirements-dev.in
|
||||
certifi==2023.11.17
|
||||
# via
|
||||
# httpcore
|
||||
@ -34,7 +29,6 @@ charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via
|
||||
# black
|
||||
# click-help-colors
|
||||
# click-shell
|
||||
# imaginAIry (setup.py)
|
||||
@ -45,11 +39,11 @@ click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.2.0
|
||||
# via matplotlib
|
||||
coverage==7.3.3
|
||||
coverage==7.4.0
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
diffusers==0.24.0
|
||||
diffusers==0.25.0
|
||||
# via imaginAIry (setup.py)
|
||||
einops==0.7.0
|
||||
# via imaginAIry (setup.py)
|
||||
@ -61,7 +55,7 @@ facexlib==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
fairscale==0.4.13
|
||||
# via imaginAIry (setup.py)
|
||||
fastapi==0.105.0
|
||||
fastapi==0.108.0
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.13.1
|
||||
# via
|
||||
@ -71,7 +65,7 @@ filelock==3.13.1
|
||||
# transformers
|
||||
filterpy==1.4.5
|
||||
# via facexlib
|
||||
fonttools==4.46.0
|
||||
fonttools==4.47.0
|
||||
# via matplotlib
|
||||
frozenlist==1.4.1
|
||||
# via
|
||||
@ -92,9 +86,9 @@ h11==0.14.0
|
||||
# uvicorn
|
||||
httpcore==1.0.2
|
||||
# via httpx
|
||||
httpx==0.25.2
|
||||
httpx==0.26.0
|
||||
# via -r requirements-dev.in
|
||||
huggingface-hub==0.19.4
|
||||
huggingface-hub==0.20.1
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -109,7 +103,7 @@ idna==3.6
|
||||
# yarl
|
||||
imageio==2.33.1
|
||||
# via imaginAIry (setup.py)
|
||||
importlib-metadata==7.0.0
|
||||
importlib-metadata==7.0.1
|
||||
# via diffusers
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
@ -119,7 +113,7 @@ jinja2==3.1.2
|
||||
# via torch
|
||||
kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
kornia==0.7.0
|
||||
kornia==0.7.1
|
||||
# via imaginAIry (setup.py)
|
||||
lightning-utilities==0.10.0
|
||||
# via
|
||||
@ -139,12 +133,10 @@ multidict==6.0.4
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy==1.7.1
|
||||
mypy==1.8.0
|
||||
# via -r requirements-dev.in
|
||||
mypy-extensions==1.0.0
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
# via mypy
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
numba==0.58.1
|
||||
@ -179,7 +171,6 @@ opencv-python==4.8.1.78
|
||||
# imaginAIry (setup.py)
|
||||
packaging==23.2
|
||||
# via
|
||||
# black
|
||||
# huggingface-hub
|
||||
# kornia
|
||||
# lightning-utilities
|
||||
@ -189,8 +180,6 @@ packaging==23.2
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
# transformers
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
pillow==10.1.0
|
||||
# via
|
||||
# diffusers
|
||||
@ -200,21 +189,19 @@ pillow==10.1.0
|
||||
# matplotlib
|
||||
# refiners
|
||||
# torchvision
|
||||
platformdirs==4.1.0
|
||||
# via black
|
||||
pluggy==1.3.0
|
||||
# via pytest
|
||||
protobuf==4.25.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
psutil==5.9.6
|
||||
psutil==5.9.7
|
||||
# via imaginAIry (setup.py)
|
||||
pydantic==2.5.2
|
||||
pydantic==2.5.3
|
||||
# via
|
||||
# fastapi
|
||||
# imaginAIry (setup.py)
|
||||
pydantic-core==2.14.5
|
||||
pydantic-core==2.14.6
|
||||
# via pydantic
|
||||
pyparsing==3.1.1
|
||||
# via matplotlib
|
||||
@ -244,7 +231,7 @@ pyyaml==6.0.1
|
||||
# transformers
|
||||
refiners==0.2.0
|
||||
# via imaginAIry (setup.py)
|
||||
regex==2023.10.3
|
||||
regex==2023.12.25
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -260,7 +247,7 @@ requests==2.31.0
|
||||
# transformers
|
||||
responses==0.24.1
|
||||
# via -r requirements-dev.in
|
||||
ruff==0.1.8
|
||||
ruff==0.1.9
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.3.3
|
||||
# via
|
||||
@ -283,7 +270,7 @@ sniffio==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
starlette==0.27.0
|
||||
starlette==0.32.0.post1
|
||||
# via fastapi
|
||||
sympy==1.12
|
||||
# via torch
|
||||
@ -297,7 +284,6 @@ tokenizers==0.15.0
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
# pytest
|
||||
torch==2.1.2
|
||||
@ -333,7 +319,7 @@ tqdm==4.66.1
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
transformers==4.36.1
|
||||
transformers==4.36.2
|
||||
# via imaginAIry (setup.py)
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
@ -347,7 +333,7 @@ types-tqdm==4.66.0.5
|
||||
# via -r requirements-dev.in
|
||||
typing-extensions==4.9.0
|
||||
# via
|
||||
# black
|
||||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# jaxtyping
|
||||
@ -363,7 +349,7 @@ urllib3==2.1.0
|
||||
# requests
|
||||
# responses
|
||||
# types-requests
|
||||
uvicorn==0.24.0.post1
|
||||
uvicorn==0.25.0
|
||||
# via imaginAIry (setup.py)
|
||||
wcwidth==0.2.12
|
||||
# via ftfy
|
||||
|
@ -1,4 +1,4 @@
|
||||
extend-ignore = ["E501", "G004", "PT004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100"]
|
||||
extend-ignore = ["E501", "G004", "PT004", "PT005", "RET504", "SIM114", "TRY003", "TRY400", "TRY401", "RUF012", "RUF100", "ISC001"]
|
||||
extend-exclude = ["imaginairy/vendored", "downloads", "other"]
|
||||
|
||||
extend-select = [
|
||||
|
Loading…
Reference in New Issue
Block a user