mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Fix error with training video models with batch greater than 1
This commit is contained in:
@@ -503,7 +503,11 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
|
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
|
||||||
|
|
||||||
# multiply by our mask
|
# multiply by our mask
|
||||||
loss = loss * mask_multiplier
|
try:
|
||||||
|
loss = loss * mask_multiplier
|
||||||
|
except:
|
||||||
|
# todo handle mask with video models
|
||||||
|
pass
|
||||||
|
|
||||||
prior_loss = None
|
prior_loss = None
|
||||||
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
||||||
@@ -524,7 +528,12 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# loss = loss + prior_loss
|
# loss = loss + prior_loss
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
# apply loss multiplier before prior loss
|
# apply loss multiplier before prior loss
|
||||||
loss = loss * loss_multiplier
|
# multiply by our mask
|
||||||
|
try:
|
||||||
|
loss = loss * loss_multiplier
|
||||||
|
except:
|
||||||
|
# todo handle mask with video models
|
||||||
|
pass
|
||||||
if prior_loss is not None:
|
if prior_loss is not None:
|
||||||
loss = loss + prior_loss
|
loss = loss + prior_loss
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user