diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3fe2f2f5..6d4f836a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -503,7 +503,11 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') # multiply by our mask - loss = loss * mask_multiplier + try: + loss = loss * mask_multiplier + except: + # todo handle mask with video models + pass prior_loss = None if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: @@ -524,7 +528,12 @@ class SDTrainer(BaseSDTrainProcess): # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) # apply loss multiplier before prior loss - loss = loss * loss_multiplier + # multiply by our mask + try: + loss = loss * loss_multiplier + except: + # todo handle mask with video models + pass if prior_loss is not None: loss = loss + prior_loss