fix: gracefully fail if older pytorch-lightning installed

This commit is contained in:
Bryce 2023-01-16 15:07:15 -08:00 committed by Bryce Drennan
parent 24e10f9e5f
commit 02af4c37b9

View File

@ -13,7 +13,13 @@ from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work
DDPStrategy = None
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
@ -387,6 +393,9 @@ def train_diffusion_model(
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
""" """
if DDPStrategy is None:
raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model")
batch_size = 1 batch_size = 1
seed = 23 seed = 23
num_workers = 1 num_workers = 1