Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-11-05 12:00:15 +00:00)
fix: gracefully fail if older pytorch-lightning installed
commit 02af4c37b9 (parent 24e10f9e5f)
@@ -13,7 +13,13 @@ from omegaconf import OmegaConf
 from PIL import Image
 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor
-from pytorch_lightning.strategies import DDPStrategy
+
+try:
+    from pytorch_lightning.strategies import DDPStrategy
+except ImportError:
+    # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
+    # Use >= 1.6.0 to make this work
+    DDPStrategy = None
 from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -387,6 +393,9 @@ def train_diffusion_model(
 
     accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
     """
+    if DDPStrategy is None:
+        raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model")
+
     batch_size = 1
     seed = 23
     num_workers = 1
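The change follows the common optional-import guard pattern: attempt the import, fall back to a None sentinel on ImportError, and raise a clear error only when the missing feature is actually used, so the rest of the package keeps importing on older pytorch-lightning. A minimal self-contained sketch of that pattern (a standalone illustration, not the repo's full module):

try:
    from pytorch_lightning.strategies import DDPStrategy
except ImportError:
    # pytorch-lightning < 1.6.0 does not ship this module; keep the rest of
    # the package importable and defer the failure to the point of use.
    DDPStrategy = None


def train():
    # Fail loudly only when training is actually requested, not at import time.
    if DDPStrategy is None:
        raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model")
    # ... training code that uses DDPStrategy ...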