build: remove pytorch lightning dependency

This commit is contained in:
Bryce 2024-01-02 06:54:31 -08:00 committed by Bryce Drennan
parent 9f33fa0664
commit 5b3b04b877
14 changed files with 37 additions and 77 deletions

View File

@ -2,6 +2,7 @@ import logging
from typing import TYPE_CHECKING, Any
from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
from imaginairy.utils import seed_everything
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
from imaginairy.utils.named_resolutions import normalize_image_size
@ -25,7 +26,6 @@ def _generate_single_image_compvis(
):
import torch.nn
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption

View File

@ -6,7 +6,7 @@ from typing import Any
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
from imaginairy.utils import clear_gpu_cache
from imaginairy.utils import clear_gpu_cache, seed_everything
from imaginairy.utils.log_utils import ImageLoggingContext
logger = logging.getLogger(__name__)
@ -27,7 +27,6 @@ def generate_single_image(
):
import torch.nn
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from tqdm import tqdm

View File

@ -15,7 +15,7 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
img = einops.rearrange(img[0], "c h w -> h w c")
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8)
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) # type: ignore
if len(blurred.shape) > 2:
blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
@ -143,7 +143,7 @@ def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
import numpy as np
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) # type: ignore
noise = noise[F : F + H, F : F + W]
noise -= np.min(noise)
noise /= np.max(noise)
@ -165,7 +165,7 @@ def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray":
x = make_noise_disk(h, w, 1, f) * float(W - 1)
y = make_noise_disk(h, w, 1, f) * float(H - 1)
flow = np.concatenate([x, y], axis=2).astype(np.float32)
return cv2.remap(img, flow, None, cv2.INTER_LINEAR)
return cv2.remap(img, flow, None, cv2.INTER_LINEAR) # type: ignore
def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor":

View File

@ -5,8 +5,8 @@ import logging
import math
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
from torch import nn
from torch.cuda import OutOfMemoryError
from imaginairy.modules.diffusion.model import Decoder, Encoder
@ -18,7 +18,7 @@ from imaginairy.utils.feather_tile import rebuild_image, tile_image
logger = logging.getLogger(__name__)
class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(nn.Module):
def __init__(
self,
ddconfig,

View File

@ -14,7 +14,6 @@ from functools import partial
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange, repeat
from omegaconf import ListConfig
@ -93,7 +92,7 @@ def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
class DDPM(pl.LightningModule):
class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
@ -1711,7 +1710,7 @@ class LatentDiffusion(DDPM):
return x
class DiffusionWrapper(pl.LightningModule):
class DiffusionWrapper(nn.Module):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)

View File

@ -7,7 +7,6 @@ from abc import abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
@ -32,7 +31,7 @@ if TYPE_CHECKING:
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
class AbstractAutoencoder(nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features

View File

@ -5,10 +5,10 @@ import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from imaginairy.modules.ema import LitEma
@ -30,7 +30,7 @@ UNCONDITIONAL_CONFIG = {
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"
class DiffusionEngine(pl.LightningModule):
class DiffusionEngine(nn.Module):
def __init__(
self,
network_config,

View File

@ -1,6 +1,8 @@
import importlib
import logging
import numpy as np
import platform
import random
import re
import time
from contextlib import contextmanager, nullcontext
@ -334,3 +336,12 @@ def clear_gpu_cache():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
torch.manual_seed(seed=seed)
torch.cuda.manual_seed_all(seed=seed)

View File

@ -389,7 +389,10 @@ def disable_transformers_custom_logging():
def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger as pytorch_logger
try:
from pytorch_lightning import _logger as pytorch_logger
except ImportError:
return
try:
from pytorch_lightning.utilities.seed import log
@ -419,7 +422,7 @@ def disable_common_warnings():
def suppress_annoying_logs_and_warnings():
disable_transformers_custom_logging()
disable_pytorch_lighting_custom_logging()
# disable_pytorch_lighting_custom_logging()
disable_common_warnings()

View File

@ -4,10 +4,6 @@
#
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
aiohttp==3.9.1
# via fsspec
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
antlr4-python3-runtime==4.9.3
@ -16,10 +12,6 @@ anyio==4.2.0
# via
# httpx
# starlette
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
certifi==2023.11.17
# via
# httpcore
@ -67,14 +59,9 @@ filterpy==1.4.5
# via facexlib
fonttools==4.47.0
# via matplotlib
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
fsspec[http]==2023.12.2
fsspec==2023.12.2
# via
# huggingface-hub
# pytorch-lightning
# torch
ftfy==6.1.3
# via
@ -100,7 +87,6 @@ idna==3.6
# anyio
# httpx
# requests
# yarl
imageio==2.33.1
# via imaginAIry (setup.py)
importlib-metadata==7.0.1
@ -115,10 +101,6 @@ kiwisolver==1.4.5
# via matplotlib
kornia==0.7.1
# via imaginAIry (setup.py)
lightning-utilities==0.10.0
# via
# pytorch-lightning
# torchmetrics
llvmlite==0.41.1
# via numba
markupsafe==2.1.3
@ -129,10 +111,6 @@ matplotlib==3.7.4
# filterpy
mpmath==1.3.0
# via sympy
multidict==6.0.4
# via
# aiohttp
# yarl
mypy==1.8.0
# via -r requirements-dev.in
mypy-extensions==1.0.0
@ -155,17 +133,15 @@ numpy==1.24.4
# matplotlib
# numba
# opencv-python
# pytorch-lightning
# refiners
# scipy
# torchmetrics
# torchvision
# transformers
omegaconf==2.3.0
# via imaginAIry (setup.py)
open-clip-torch==2.23.0
# via imaginAIry (setup.py)
opencv-python==4.8.1.78
opencv-python==4.9.0.80
# via
# facexlib
# imaginAIry (setup.py)
@ -173,14 +149,11 @@ packaging==23.2
# via
# huggingface-hub
# kornia
# lightning-utilities
# matplotlib
# pytest
# pytest-sugar
# pytorch-lightning
# torchmetrics
# transformers
pillow==10.1.0
pillow==10.2.0
# via
# diffusers
# facexlib
@ -205,13 +178,13 @@ pydantic-core==2.14.6
# via pydantic
pyparsing==3.1.1
# via matplotlib
pytest==7.4.3
pytest==7.4.4
# via
# -r requirements-dev.in
# pytest-asyncio
# pytest-randomly
# pytest-sugar
pytest-asyncio==0.23.2
pytest-asyncio==0.23.3
# via -r requirements-dev.in
pytest-randomly==3.15.0
# via -r requirements-dev.in
@ -219,13 +192,10 @@ pytest-sugar==0.9.7
# via -r requirements-dev.in
python-dateutil==2.8.2
# via matplotlib
pytorch-lightning==1.9.5
# via imaginAIry (setup.py)
pyyaml==6.0.1
# via
# huggingface-hub
# omegaconf
# pytorch-lightning
# responses
# timm
# transformers
@ -239,7 +209,6 @@ regex==2023.12.25
requests==2.31.0
# via
# diffusers
# fsspec
# huggingface-hub
# imaginAIry (setup.py)
# responses
@ -293,18 +262,12 @@ torch==2.1.2
# imaginAIry (setup.py)
# kornia
# open-clip-torch
# pytorch-lightning
# refiners
# timm
# torchdiffeq
# torchmetrics
# torchvision
torchdiffeq==0.2.3
# via imaginAIry (setup.py)
torchmetrics==1.2.1
# via
# imaginAIry (setup.py)
# pytorch-lightning
torchvision==0.16.2
# via
# facexlib
@ -317,7 +280,6 @@ tqdm==4.66.1
# huggingface-hub
# imaginAIry (setup.py)
# open-clip-torch
# pytorch-lightning
# transformers
transformers==4.36.2
# via imaginAIry (setup.py)
@ -327,7 +289,7 @@ types-pillow==10.1.0.2
# via -r requirements-dev.in
types-psutil==5.9.5.17
# via -r requirements-dev.in
types-requests==2.31.0.10
types-requests==2.31.0.20231231
# via -r requirements-dev.in
types-tqdm==4.66.0.5
# via -r requirements-dev.in
@ -337,11 +299,9 @@ typing-extensions==4.9.0
# fastapi
# huggingface-hub
# jaxtyping
# lightning-utilities
# mypy
# pydantic
# pydantic-core
# pytorch-lightning
# torch
# uvicorn
urllib3==2.1.0
@ -355,10 +315,5 @@ wcwidth==0.2.12
# via ftfy
wheel==0.42.0
# via -r requirements-dev.in
yarl==1.9.4
# via aiohttp
zipp==3.17.0
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@ -81,7 +81,6 @@ setup(
"fairscale>=0.4.4", # for vendored blip
"fastapi>=0.70.0",
"ftfy>=6.0.1", # for vendored clip
# 2.0.0 produced garbage images on macOS
"torch>=2.1.0",
# https://numpy.org/neps/nep-0029-deprecation_policy.html
"numpy>=1.19.0,<1.26.0",
@ -90,8 +89,6 @@ setup(
"imageio>=2.9.0",
"Pillow>=9.1.0",
"psutil>5.7.3",
# 2.0.0 need to fix `ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' `
"pytorch-lightning>=1.4.2,<2.0.0",
"omegaconf>=2.1.1",
"open-clip-torch>=2.0.0",
"opencv-python>=4.4.0.46",
@ -105,7 +102,6 @@ setup(
"scipy<1.11",
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
"torchdiffeq>=0.2.0",
"torchmetrics>=0.6.0",
"torchvision>=0.13.1",
"transformers>=4.19.2",
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",

View File

@ -1,11 +1,11 @@
import itertools
import pytest
from lightning_fabric import seed_everything
from imaginairy.img_processors.control_modes import CONTROL_MODES, create_depth_map
from imaginairy.modules.midas.api import ISL_PATHS
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import seed_everything
from imaginairy.utils.img_utils import (
pillow_img_to_torch_image,
torch_img_to_pillow_img,

View File

@ -1,9 +1,8 @@
import pytest
from lightning_fabric import seed_everything
from PIL import Image
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.utils import get_device
from imaginairy.utils import get_device, seed_everything
from tests import TESTS_FOLDER

View File

@ -1,9 +1,8 @@
import pytest
from lightning_fabric import seed_everything
from PIL import Image
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device
from imaginairy.utils import get_device, seed_everything
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation