Change img multiplier math

This commit is contained in:
Jaret Burkett
2024-07-30 11:33:41 -06:00
parent 443c996e7f
commit 47744373f2
4 changed files with 18 additions and 3 deletions

View File

@@ -457,7 +457,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.target_norm_std:
# seperate out the batch and channels
pred_std = noise_pred.std([2, 3], keepdim=True)
norm_std_loss = torch.abs(1.0 - pred_std).mean()
norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
loss = loss + norm_std_loss