build: remove pytorch lightning dependency

This commit is contained in:
Bryce 2024-01-02 06:54:31 -08:00 committed by Bryce Drennan
parent 9f33fa0664
commit 5b3b04b877
14 changed files with 37 additions and 77 deletions

View File

@ -2,6 +2,7 @@ import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
from imaginairy.utils import seed_everything
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
from imaginairy.utils.named_resolutions import normalize_image_size from imaginairy.utils.named_resolutions import normalize_image_size
@ -25,7 +26,6 @@ def _generate_single_image_compvis(
): ):
import torch.nn import torch.nn
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.describe_image_blip import generate_caption

View File

@ -6,7 +6,7 @@ from typing import Any
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
from imaginairy.utils import clear_gpu_cache from imaginairy.utils import clear_gpu_cache, seed_everything
from imaginairy.utils.log_utils import ImageLoggingContext from imaginairy.utils.log_utils import ImageLoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,6 @@ def generate_single_image(
): ):
import torch.nn import torch.nn
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from tqdm import tqdm from tqdm import tqdm

View File

@ -15,7 +15,7 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0) img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
img = einops.rearrange(img[0], "c h w -> h w c") img = einops.rearrange(img[0], "c h w -> h w c")
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze() img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) # type: ignore
if len(blurred.shape) > 2: if len(blurred.shape) > 2:
blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY) blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
@ -143,7 +143,7 @@ def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
import numpy as np import numpy as np
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) # type: ignore
noise = noise[F : F + H, F : F + W] noise = noise[F : F + H, F : F + W]
noise -= np.min(noise) noise -= np.min(noise)
noise /= np.max(noise) noise /= np.max(noise)
@ -165,7 +165,7 @@ def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray":
x = make_noise_disk(h, w, 1, f) * float(W - 1) x = make_noise_disk(h, w, 1, f) * float(W - 1)
y = make_noise_disk(h, w, 1, f) * float(H - 1) y = make_noise_disk(h, w, 1, f) * float(H - 1)
flow = np.concatenate([x, y], axis=2).astype(np.float32) flow = np.concatenate([x, y], axis=2).astype(np.float32)
return cv2.remap(img, flow, None, cv2.INTER_LINEAR) return cv2.remap(img, flow, None, cv2.INTER_LINEAR) # type: ignore
def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor": def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor":

View File

@ -5,8 +5,8 @@ import logging
import math import math
from contextlib import contextmanager from contextlib import contextmanager
import pytorch_lightning as pl
import torch import torch
from torch import nn
from torch.cuda import OutOfMemoryError from torch.cuda import OutOfMemoryError
from imaginairy.modules.diffusion.model import Decoder, Encoder from imaginairy.modules.diffusion.model import Decoder, Encoder
@ -18,7 +18,7 @@ from imaginairy.utils.feather_tile import rebuild_image, tile_image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AutoencoderKL(pl.LightningModule): class AutoencoderKL(nn.Module):
def __init__( def __init__(
self, self,
ddconfig, ddconfig,

View File

@ -14,7 +14,6 @@ from functools import partial
from typing import Optional from typing import Optional
import numpy as np import numpy as np
import pytorch_lightning as pl
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
from omegaconf import ListConfig from omegaconf import ListConfig
@ -93,7 +92,7 @@ def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2 return (r1 - r2) * torch.rand(*shape, device=device) + r2
class DDPM(pl.LightningModule): class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space # classic DDPM with Gaussian diffusion, in image space
def __init__( def __init__(
self, self,
@ -1711,7 +1710,7 @@ class LatentDiffusion(DDPM):
return x return x
class DiffusionWrapper(pl.LightningModule): class DiffusionWrapper(nn.Module):
def __init__(self, diff_model_config, conditioning_key): def __init__(self, diff_model_config, conditioning_key):
super().__init__() super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config) self.diffusion_model = instantiate_from_config(diff_model_config)

View File

@ -7,7 +7,6 @@ from abc import abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
@ -32,7 +31,7 @@ if TYPE_CHECKING:
logpy = logging.getLogger(__name__) logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule): class AbstractAutoencoder(nn.Module):
""" """
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features unCLIP models, etc. Hence, it is fairly general, and specific features

View File

@ -5,10 +5,10 @@ import math
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch import torch
from omegaconf import ListConfig, OmegaConf from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors from safetensors.torch import load_file as load_safetensors
from torch import nn
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from imaginairy.modules.ema import LitEma from imaginairy.modules.ema import LitEma
@ -30,7 +30,7 @@ UNCONDITIONAL_CONFIG = {
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper" OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"
class DiffusionEngine(pl.LightningModule): class DiffusionEngine(nn.Module):
def __init__( def __init__(
self, self,
network_config, network_config,

View File

@ -1,6 +1,8 @@
import importlib import importlib
import logging import logging
import numpy as np
import platform import platform
import random
import re import re
import time import time
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
@ -334,3 +336,12 @@ def clear_gpu_cache():
gc.collect() gc.collect()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
torch.manual_seed(seed=seed)
torch.cuda.manual_seed_all(seed=seed)

View File

@ -389,7 +389,10 @@ def disable_transformers_custom_logging():
def disable_pytorch_lighting_custom_logging(): def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger as pytorch_logger try:
from pytorch_lightning import _logger as pytorch_logger
except ImportError:
return
try: try:
from pytorch_lightning.utilities.seed import log from pytorch_lightning.utilities.seed import log
@ -419,7 +422,7 @@ def disable_common_warnings():
def suppress_annoying_logs_and_warnings(): def suppress_annoying_logs_and_warnings():
disable_transformers_custom_logging() disable_transformers_custom_logging()
disable_pytorch_lighting_custom_logging() # disable_pytorch_lighting_custom_logging()
disable_common_warnings() disable_common_warnings()

View File

@ -4,10 +4,6 @@
# #
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py # pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
# #
aiohttp==3.9.1
# via fsspec
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0 annotated-types==0.6.0
# via pydantic # via pydantic
antlr4-python3-runtime==4.9.3 antlr4-python3-runtime==4.9.3
@ -16,10 +12,6 @@ anyio==4.2.0
# via # via
# httpx # httpx
# starlette # starlette
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
certifi==2023.11.17 certifi==2023.11.17
# via # via
# httpcore # httpcore
@ -67,14 +59,9 @@ filterpy==1.4.5
# via facexlib # via facexlib
fonttools==4.47.0 fonttools==4.47.0
# via matplotlib # via matplotlib
frozenlist==1.4.1 fsspec==2023.12.2
# via
# aiohttp
# aiosignal
fsspec[http]==2023.12.2
# via # via
# huggingface-hub # huggingface-hub
# pytorch-lightning
# torch # torch
ftfy==6.1.3 ftfy==6.1.3
# via # via
@ -100,7 +87,6 @@ idna==3.6
# anyio # anyio
# httpx # httpx
# requests # requests
# yarl
imageio==2.33.1 imageio==2.33.1
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
importlib-metadata==7.0.1 importlib-metadata==7.0.1
@ -115,10 +101,6 @@ kiwisolver==1.4.5
# via matplotlib # via matplotlib
kornia==0.7.1 kornia==0.7.1
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
lightning-utilities==0.10.0
# via
# pytorch-lightning
# torchmetrics
llvmlite==0.41.1 llvmlite==0.41.1
# via numba # via numba
markupsafe==2.1.3 markupsafe==2.1.3
@ -129,10 +111,6 @@ matplotlib==3.7.4
# filterpy # filterpy
mpmath==1.3.0 mpmath==1.3.0
# via sympy # via sympy
multidict==6.0.4
# via
# aiohttp
# yarl
mypy==1.8.0 mypy==1.8.0
# via -r requirements-dev.in # via -r requirements-dev.in
mypy-extensions==1.0.0 mypy-extensions==1.0.0
@ -155,17 +133,15 @@ numpy==1.24.4
# matplotlib # matplotlib
# numba # numba
# opencv-python # opencv-python
# pytorch-lightning
# refiners # refiners
# scipy # scipy
# torchmetrics
# torchvision # torchvision
# transformers # transformers
omegaconf==2.3.0 omegaconf==2.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
open-clip-torch==2.23.0 open-clip-torch==2.23.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
opencv-python==4.8.1.78 opencv-python==4.9.0.80
# via # via
# facexlib # facexlib
# imaginAIry (setup.py) # imaginAIry (setup.py)
@ -173,14 +149,11 @@ packaging==23.2
# via # via
# huggingface-hub # huggingface-hub
# kornia # kornia
# lightning-utilities
# matplotlib # matplotlib
# pytest # pytest
# pytest-sugar # pytest-sugar
# pytorch-lightning
# torchmetrics
# transformers # transformers
pillow==10.1.0 pillow==10.2.0
# via # via
# diffusers # diffusers
# facexlib # facexlib
@ -205,13 +178,13 @@ pydantic-core==2.14.6
# via pydantic # via pydantic
pyparsing==3.1.1 pyparsing==3.1.1
# via matplotlib # via matplotlib
pytest==7.4.3 pytest==7.4.4
# via # via
# -r requirements-dev.in # -r requirements-dev.in
# pytest-asyncio # pytest-asyncio
# pytest-randomly # pytest-randomly
# pytest-sugar # pytest-sugar
pytest-asyncio==0.23.2 pytest-asyncio==0.23.3
# via -r requirements-dev.in # via -r requirements-dev.in
pytest-randomly==3.15.0 pytest-randomly==3.15.0
# via -r requirements-dev.in # via -r requirements-dev.in
@ -219,13 +192,10 @@ pytest-sugar==0.9.7
# via -r requirements-dev.in # via -r requirements-dev.in
python-dateutil==2.8.2 python-dateutil==2.8.2
# via matplotlib # via matplotlib
pytorch-lightning==1.9.5
# via imaginAIry (setup.py)
pyyaml==6.0.1 pyyaml==6.0.1
# via # via
# huggingface-hub # huggingface-hub
# omegaconf # omegaconf
# pytorch-lightning
# responses # responses
# timm # timm
# transformers # transformers
@ -239,7 +209,6 @@ regex==2023.12.25
requests==2.31.0 requests==2.31.0
# via # via
# diffusers # diffusers
# fsspec
# huggingface-hub # huggingface-hub
# imaginAIry (setup.py) # imaginAIry (setup.py)
# responses # responses
@ -293,18 +262,12 @@ torch==2.1.2
# imaginAIry (setup.py) # imaginAIry (setup.py)
# kornia # kornia
# open-clip-torch # open-clip-torch
# pytorch-lightning
# refiners # refiners
# timm # timm
# torchdiffeq # torchdiffeq
# torchmetrics
# torchvision # torchvision
torchdiffeq==0.2.3 torchdiffeq==0.2.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
torchmetrics==1.2.1
# via
# imaginAIry (setup.py)
# pytorch-lightning
torchvision==0.16.2 torchvision==0.16.2
# via # via
# facexlib # facexlib
@ -317,7 +280,6 @@ tqdm==4.66.1
# huggingface-hub # huggingface-hub
# imaginAIry (setup.py) # imaginAIry (setup.py)
# open-clip-torch # open-clip-torch
# pytorch-lightning
# transformers # transformers
transformers==4.36.2 transformers==4.36.2
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
@ -327,7 +289,7 @@ types-pillow==10.1.0.2
# via -r requirements-dev.in # via -r requirements-dev.in
types-psutil==5.9.5.17 types-psutil==5.9.5.17
# via -r requirements-dev.in # via -r requirements-dev.in
types-requests==2.31.0.10 types-requests==2.31.0.20231231
# via -r requirements-dev.in # via -r requirements-dev.in
types-tqdm==4.66.0.5 types-tqdm==4.66.0.5
# via -r requirements-dev.in # via -r requirements-dev.in
@ -337,11 +299,9 @@ typing-extensions==4.9.0
# fastapi # fastapi
# huggingface-hub # huggingface-hub
# jaxtyping # jaxtyping
# lightning-utilities
# mypy # mypy
# pydantic # pydantic
# pydantic-core # pydantic-core
# pytorch-lightning
# torch # torch
# uvicorn # uvicorn
urllib3==2.1.0 urllib3==2.1.0
@ -355,10 +315,5 @@ wcwidth==0.2.12
# via ftfy # via ftfy
wheel==0.42.0 wheel==0.42.0
# via -r requirements-dev.in # via -r requirements-dev.in
yarl==1.9.4
# via aiohttp
zipp==3.17.0 zipp==3.17.0
# via importlib-metadata # via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@ -81,7 +81,6 @@ setup(
"fairscale>=0.4.4", # for vendored blip "fairscale>=0.4.4", # for vendored blip
"fastapi>=0.70.0", "fastapi>=0.70.0",
"ftfy>=6.0.1", # for vendored clip "ftfy>=6.0.1", # for vendored clip
# 2.0.0 produced garbage images on macOS
"torch>=2.1.0", "torch>=2.1.0",
# https://numpy.org/neps/nep-0029-deprecation_policy.html # https://numpy.org/neps/nep-0029-deprecation_policy.html
"numpy>=1.19.0,<1.26.0", "numpy>=1.19.0,<1.26.0",
@ -90,8 +89,6 @@ setup(
"imageio>=2.9.0", "imageio>=2.9.0",
"Pillow>=9.1.0", "Pillow>=9.1.0",
"psutil>5.7.3", "psutil>5.7.3",
# 2.0.0 need to fix `ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' `
"pytorch-lightning>=1.4.2,<2.0.0",
"omegaconf>=2.1.1", "omegaconf>=2.1.1",
"open-clip-torch>=2.0.0", "open-clip-torch>=2.0.0",
"opencv-python>=4.4.0.46", "opencv-python>=4.4.0.46",
@ -105,7 +102,6 @@ setup(
"scipy<1.11", "scipy<1.11",
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip "timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
"torchdiffeq>=0.2.0", "torchdiffeq>=0.2.0",
"torchmetrics>=0.6.0",
"torchvision>=0.13.1", "torchvision>=0.13.1",
"transformers>=4.19.2", "transformers>=4.19.2",
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'", "triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",

View File

@ -1,11 +1,11 @@
import itertools import itertools
import pytest import pytest
from lightning_fabric import seed_everything
from imaginairy.img_processors.control_modes import CONTROL_MODES, create_depth_map from imaginairy.img_processors.control_modes import CONTROL_MODES, create_depth_map
from imaginairy.modules.midas.api import ISL_PATHS from imaginairy.modules.midas.api import ISL_PATHS
from imaginairy.schema import LazyLoadingImage from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import seed_everything
from imaginairy.utils.img_utils import ( from imaginairy.utils.img_utils import (
pillow_img_to_torch_image, pillow_img_to_torch_image,
torch_img_to_pillow_img, torch_img_to_pillow_img,

View File

@ -1,9 +1,8 @@
import pytest import pytest
from lightning_fabric import seed_everything
from PIL import Image from PIL import Image
from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.utils import get_device from imaginairy.utils import get_device, seed_everything
from tests import TESTS_FOLDER from tests import TESTS_FOLDER

View File

@ -1,9 +1,8 @@
import pytest import pytest
from lightning_fabric import seed_everything
from PIL import Image from PIL import Image
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device from imaginairy.utils import get_device, seed_everything
from tests import TESTS_FOLDER from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation from tests.utils import assert_image_similar_to_expectation