Fixed some new bugs i added. woops

This commit is contained in:
Jaret Burkett
2023-12-28 14:03:42 -07:00
parent eeee4a1620
commit 0892dec4a5
2 changed files with 3 additions and 3 deletions

View File

@@ -163,7 +163,7 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss * mask_multiplier
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None:
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
# to a loss to unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),