From 242c04a0b85e76a3d5964bd3bc5a8f1b0cd4bbab Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Mar 2025 18:47:27 -0700 Subject: [PATCH] Fix error with training video models with batch greater than 1 --- extensions_built_in/sd_trainer/SDTrainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3fe2f2f5..6d4f836a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -503,7 +503,11 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') # multiply by our mask - loss = loss * mask_multiplier + try: + loss = loss * mask_multiplier + except: + # todo handle mask with video models + pass prior_loss = None if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: @@ -524,7 +528,12 @@ class SDTrainer(BaseSDTrainProcess): # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) # apply loss multiplier before prior loss - loss = loss * loss_multiplier + # multiply by our mask + try: + loss = loss * loss_multiplier + except: + # todo handle mask with video models + pass if prior_loss is not None: loss = loss + prior_loss