diff --git a/imaginairy/train.py b/imaginairy/train.py index 6064e4d..8c1c584 100644 --- a/imaginairy/train.py +++ b/imaginairy/train.py @@ -13,7 +13,13 @@ from omegaconf import OmegaConf from PIL import Image from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import Callback, LearningRateMonitor -from pytorch_lightning.strategies import DDPStrategy + +try: + from pytorch_lightning.strategies import DDPStrategy +except ImportError: + # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL + # Use >= 1.6.0 to make this work + DDPStrategy = None from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only @@ -387,6 +393,9 @@ def train_diffusion_model( accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf """ + if DDPStrategy is None: + raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model") + batch_size = 1 seed = 23 num_workers = 1