Add audio_loss_multiplier to scale audio loss to larger values if desired.

This commit is contained in:
Jaret Burkett
2026-02-19 11:57:44 -07:00
parent 3632656cda
commit 1c74ca5d22
6 changed files with 39 additions and 10 deletions

View File

@@ -863,6 +863,7 @@ class SDTrainer(BaseSDTrainProcess):
# check for audio loss
if batch.audio_pred is not None and batch.audio_target is not None:
audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean")
audio_loss = audio_loss * self.train_config.audio_loss_multiplier
loss = loss + audio_loss
# check for additional losses