mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Add audio_loss_multiplier to scale audio loss to larger values if desired.
This commit is contained in:
@@ -863,6 +863,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# check for audio loss
|
||||
if batch.audio_pred is not None and batch.audio_target is not None:
|
||||
audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean")
|
||||
audio_loss = audio_loss * self.train_config.audio_loss_multiplier
|
||||
loss = loss + audio_loss
|
||||
|
||||
# check for additional losses
|
||||
|
||||
Reference in New Issue
Block a user