Fix error with training video models with batch greater than 1

This commit is contained in:
Jaret Burkett
2025-03-08 18:47:27 -07:00
parent 386e68a422
commit 242c04a0b8

View File

@@ -503,7 +503,11 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
# multiply by our mask
loss = loss * mask_multiplier
try:
loss = loss * mask_multiplier
except:
# todo handle mask with video models
pass
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
@@ -524,7 +528,12 @@ class SDTrainer(BaseSDTrainProcess):
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
# apply loss multiplier before prior loss
loss = loss * loss_multiplier
# multiply by our mask
try:
loss = loss * loss_multiplier
except:
# todo handle mask with video models
pass
if prior_loss is not None:
loss = loss + prior_loss