mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
build: remove pytorch lightning dependency
This commit is contained in:
parent
9f33fa0664
commit
5b3b04b877
@ -2,6 +2,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
|
||||
from imaginairy.utils import seed_everything
|
||||
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
|
||||
from imaginairy.utils.named_resolutions import normalize_image_size
|
||||
|
||||
@ -25,7 +26,6 @@ def _generate_single_image_compvis(
|
||||
):
|
||||
import torch.nn
|
||||
from PIL import Image, ImageOps
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||
from imaginairy.enhancers.describe_image_blip import generate_caption
|
||||
|
@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
|
||||
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
|
||||
from imaginairy.utils import clear_gpu_cache
|
||||
from imaginairy.utils import clear_gpu_cache, seed_everything
|
||||
from imaginairy.utils.log_utils import ImageLoggingContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -27,7 +27,6 @@ def generate_single_image(
|
||||
):
|
||||
import torch.nn
|
||||
from PIL import Image, ImageOps
|
||||
from pytorch_lightning import seed_everything
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -15,7 +15,7 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
|
||||
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
img = einops.rearrange(img[0], "c h w -> h w c")
|
||||
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
|
||||
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8)
|
||||
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) # type: ignore
|
||||
|
||||
if len(blurred.shape) > 2:
|
||||
blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
|
||||
@ -143,7 +143,7 @@ def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
|
||||
import numpy as np
|
||||
|
||||
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
|
||||
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
|
||||
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) # type: ignore
|
||||
noise = noise[F : F + H, F : F + W]
|
||||
noise -= np.min(noise)
|
||||
noise /= np.max(noise)
|
||||
@ -165,7 +165,7 @@ def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray":
|
||||
x = make_noise_disk(h, w, 1, f) * float(W - 1)
|
||||
y = make_noise_disk(h, w, 1, f) * float(H - 1)
|
||||
flow = np.concatenate([x, y], axis=2).astype(np.float32)
|
||||
return cv2.remap(img, flow, None, cv2.INTER_LINEAR)
|
||||
return cv2.remap(img, flow, None, cv2.INTER_LINEAR) # type: ignore
|
||||
|
||||
|
||||
def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor":
|
||||
|
@ -5,8 +5,8 @@ import logging
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda import OutOfMemoryError
|
||||
|
||||
from imaginairy.modules.diffusion.model import Decoder, Encoder
|
||||
@ -18,7 +18,7 @@ from imaginairy.utils.feather_tile import rebuild_image, tile_image
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
class AutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
|
@ -14,7 +14,6 @@ from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import ListConfig
|
||||
@ -93,7 +92,7 @@ def uniform_on_device(r1, r2, shape, device):
|
||||
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
||||
|
||||
|
||||
class DDPM(pl.LightningModule):
|
||||
class DDPM(nn.Module):
|
||||
# classic DDPM with Gaussian diffusion, in image space
|
||||
def __init__(
|
||||
self,
|
||||
@ -1711,7 +1710,7 @@ class LatentDiffusion(DDPM):
|
||||
return x
|
||||
|
||||
|
||||
class DiffusionWrapper(pl.LightningModule):
|
||||
class DiffusionWrapper(nn.Module):
|
||||
def __init__(self, diff_model_config, conditioning_key):
|
||||
super().__init__()
|
||||
self.diffusion_model = instantiate_from_config(diff_model_config)
|
||||
|
@ -7,7 +7,6 @@ from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
@ -32,7 +31,7 @@ if TYPE_CHECKING:
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractAutoencoder(pl.LightningModule):
|
||||
class AbstractAutoencoder(nn.Module):
|
||||
"""
|
||||
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
||||
unCLIP models, etc. Hence, it is fairly general, and specific features
|
||||
|
@ -5,10 +5,10 @@ import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from imaginairy.modules.ema import LitEma
|
||||
@ -30,7 +30,7 @@ UNCONDITIONAL_CONFIG = {
|
||||
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"
|
||||
|
||||
|
||||
class DiffusionEngine(pl.LightningModule):
|
||||
class DiffusionEngine(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
network_config,
|
||||
|
@ -1,6 +1,8 @@
|
||||
import importlib
|
||||
import logging
|
||||
import numpy as np
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from contextlib import contextmanager, nullcontext
|
||||
@ -334,3 +336,12 @@ def clear_gpu_cache():
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def seed_everything(seed: int | None = None) -> None:
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Using random seed: {seed}")
|
||||
random.seed(a=seed)
|
||||
np.random.seed(seed=seed)
|
||||
torch.manual_seed(seed=seed)
|
||||
torch.cuda.manual_seed_all(seed=seed)
|
@ -389,7 +389,10 @@ def disable_transformers_custom_logging():
|
||||
|
||||
|
||||
def disable_pytorch_lighting_custom_logging():
|
||||
try:
|
||||
from pytorch_lightning import _logger as pytorch_logger
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from pytorch_lightning.utilities.seed import log
|
||||
@ -419,7 +422,7 @@ def disable_common_warnings():
|
||||
|
||||
def suppress_annoying_logs_and_warnings():
|
||||
disable_transformers_custom_logging()
|
||||
disable_pytorch_lighting_custom_logging()
|
||||
# disable_pytorch_lighting_custom_logging()
|
||||
disable_common_warnings()
|
||||
|
||||
|
||||
|
@ -4,10 +4,6 @@
|
||||
#
|
||||
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
|
||||
#
|
||||
aiohttp==3.9.1
|
||||
# via fsspec
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
@ -16,10 +12,6 @@ anyio==4.2.0
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
attrs==23.1.0
|
||||
# via aiohttp
|
||||
certifi==2023.11.17
|
||||
# via
|
||||
# httpcore
|
||||
@ -67,14 +59,9 @@ filterpy==1.4.5
|
||||
# via facexlib
|
||||
fonttools==4.47.0
|
||||
# via matplotlib
|
||||
frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2023.12.2
|
||||
fsspec==2023.12.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pytorch-lightning
|
||||
# torch
|
||||
ftfy==6.1.3
|
||||
# via
|
||||
@ -100,7 +87,6 @@ idna==3.6
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.33.1
|
||||
# via imaginAIry (setup.py)
|
||||
importlib-metadata==7.0.1
|
||||
@ -115,10 +101,6 @@ kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
kornia==0.7.1
|
||||
# via imaginAIry (setup.py)
|
||||
lightning-utilities==0.10.0
|
||||
# via
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
llvmlite==0.41.1
|
||||
# via numba
|
||||
markupsafe==2.1.3
|
||||
@ -129,10 +111,6 @@ matplotlib==3.7.4
|
||||
# filterpy
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.0.4
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy==1.8.0
|
||||
# via -r requirements-dev.in
|
||||
mypy-extensions==1.0.0
|
||||
@ -155,17 +133,15 @@ numpy==1.24.4
|
||||
# matplotlib
|
||||
# numba
|
||||
# opencv-python
|
||||
# pytorch-lightning
|
||||
# refiners
|
||||
# scipy
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
# transformers
|
||||
omegaconf==2.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
open-clip-torch==2.23.0
|
||||
# via imaginAIry (setup.py)
|
||||
opencv-python==4.8.1.78
|
||||
opencv-python==4.9.0.80
|
||||
# via
|
||||
# facexlib
|
||||
# imaginAIry (setup.py)
|
||||
@ -173,14 +149,11 @@ packaging==23.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# kornia
|
||||
# lightning-utilities
|
||||
# matplotlib
|
||||
# pytest
|
||||
# pytest-sugar
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
# transformers
|
||||
pillow==10.1.0
|
||||
pillow==10.2.0
|
||||
# via
|
||||
# diffusers
|
||||
# facexlib
|
||||
@ -205,13 +178,13 @@ pydantic-core==2.14.6
|
||||
# via pydantic
|
||||
pyparsing==3.1.1
|
||||
# via matplotlib
|
||||
pytest==7.4.3
|
||||
pytest==7.4.4
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pytest-asyncio
|
||||
# pytest-randomly
|
||||
# pytest-sugar
|
||||
pytest-asyncio==0.23.2
|
||||
pytest-asyncio==0.23.3
|
||||
# via -r requirements-dev.in
|
||||
pytest-randomly==3.15.0
|
||||
# via -r requirements-dev.in
|
||||
@ -219,13 +192,10 @@ pytest-sugar==0.9.7
|
||||
# via -r requirements-dev.in
|
||||
python-dateutil==2.8.2
|
||||
# via matplotlib
|
||||
pytorch-lightning==1.9.5
|
||||
# via imaginAIry (setup.py)
|
||||
pyyaml==6.0.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# omegaconf
|
||||
# pytorch-lightning
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
@ -239,7 +209,6 @@ regex==2023.12.25
|
||||
requests==2.31.0
|
||||
# via
|
||||
# diffusers
|
||||
# fsspec
|
||||
# huggingface-hub
|
||||
# imaginAIry (setup.py)
|
||||
# responses
|
||||
@ -293,18 +262,12 @@ torch==2.1.2
|
||||
# imaginAIry (setup.py)
|
||||
# kornia
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# refiners
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
torchdiffeq==0.2.3
|
||||
# via imaginAIry (setup.py)
|
||||
torchmetrics==1.2.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# pytorch-lightning
|
||||
torchvision==0.16.2
|
||||
# via
|
||||
# facexlib
|
||||
@ -317,7 +280,6 @@ tqdm==4.66.1
|
||||
# huggingface-hub
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
transformers==4.36.2
|
||||
# via imaginAIry (setup.py)
|
||||
@ -327,7 +289,7 @@ types-pillow==10.1.0.2
|
||||
# via -r requirements-dev.in
|
||||
types-psutil==5.9.5.17
|
||||
# via -r requirements-dev.in
|
||||
types-requests==2.31.0.10
|
||||
types-requests==2.31.0.20231231
|
||||
# via -r requirements-dev.in
|
||||
types-tqdm==4.66.0.5
|
||||
# via -r requirements-dev.in
|
||||
@ -337,11 +299,9 @@ typing-extensions==4.9.0
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# jaxtyping
|
||||
# lightning-utilities
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pytorch-lightning
|
||||
# torch
|
||||
# uvicorn
|
||||
urllib3==2.1.0
|
||||
@ -355,10 +315,5 @@ wcwidth==0.2.12
|
||||
# via ftfy
|
||||
wheel==0.42.0
|
||||
# via -r requirements-dev.in
|
||||
yarl==1.9.4
|
||||
# via aiohttp
|
||||
zipp==3.17.0
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
4
setup.py
4
setup.py
@ -81,7 +81,6 @@ setup(
|
||||
"fairscale>=0.4.4", # for vendored blip
|
||||
"fastapi>=0.70.0",
|
||||
"ftfy>=6.0.1", # for vendored clip
|
||||
# 2.0.0 produced garbage images on macOS
|
||||
"torch>=2.1.0",
|
||||
# https://numpy.org/neps/nep-0029-deprecation_policy.html
|
||||
"numpy>=1.19.0,<1.26.0",
|
||||
@ -90,8 +89,6 @@ setup(
|
||||
"imageio>=2.9.0",
|
||||
"Pillow>=9.1.0",
|
||||
"psutil>5.7.3",
|
||||
# 2.0.0 need to fix `ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' `
|
||||
"pytorch-lightning>=1.4.2,<2.0.0",
|
||||
"omegaconf>=2.1.1",
|
||||
"open-clip-torch>=2.0.0",
|
||||
"opencv-python>=4.4.0.46",
|
||||
@ -105,7 +102,6 @@ setup(
|
||||
"scipy<1.11",
|
||||
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
|
||||
"torchdiffeq>=0.2.0",
|
||||
"torchmetrics>=0.6.0",
|
||||
"torchvision>=0.13.1",
|
||||
"transformers>=4.19.2",
|
||||
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||
|
@ -1,11 +1,11 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
from lightning_fabric import seed_everything
|
||||
|
||||
from imaginairy.img_processors.control_modes import CONTROL_MODES, create_depth_map
|
||||
from imaginairy.modules.midas.api import ISL_PATHS
|
||||
from imaginairy.schema import LazyLoadingImage
|
||||
from imaginairy.utils import seed_everything
|
||||
from imaginairy.utils.img_utils import (
|
||||
pillow_img_to_torch_image,
|
||||
torch_img_to_pillow_img,
|
||||
|
@ -1,9 +1,8 @@
|
||||
import pytest
|
||||
from lightning_fabric import seed_everything
|
||||
from PIL import Image
|
||||
|
||||
from imaginairy.enhancers.describe_image_blip import generate_caption
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.utils import get_device, seed_everything
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import pytest
|
||||
from lightning_fabric import seed_everything
|
||||
from PIL import Image
|
||||
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.utils import get_device, seed_everything
|
||||
from tests import TESTS_FOLDER
|
||||
from tests.utils import assert_image_similar_to_expectation
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user