mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix error with training video models with batch greater than 1
This commit is contained in:
@@ -503,7 +503,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
|
||||
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
try:
|
||||
loss = loss * mask_multiplier
|
||||
except:
|
||||
# todo handle mask with video models
|
||||
pass
|
||||
|
||||
prior_loss = None
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
||||
@@ -524,7 +528,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# apply loss multiplier before prior loss
|
||||
loss = loss * loss_multiplier
|
||||
# multiply by our mask
|
||||
try:
|
||||
loss = loss * loss_multiplier
|
||||
except:
|
||||
# todo handle mask with video models
|
||||
pass
|
||||
if prior_loss is not None:
|
||||
loss = loss + prior_loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user