Add audio_loss_multiplier to scale audio loss to larger values if desired.

This commit is contained in:
Jaret Burkett
2026-02-19 11:57:44 -07:00
parent 3632656cda
commit 1c74ca5d22
6 changed files with 39 additions and 10 deletions

View File

@@ -863,6 +863,7 @@ class SDTrainer(BaseSDTrainProcess):
# check for audio loss
if batch.audio_pred is not None and batch.audio_target is not None:
audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean")
audio_loss = audio_loss * self.train_config.audio_loss_multiplier
loss = loss + audio_loss
# check for additional losses

View File

@@ -683,6 +683,8 @@ class ModelConfig:
# model paths for models that support it
self.model_paths = kwargs.get("model_paths", {})
self.audio_loss_multiplier = kwargs.get("audio_loss_multiplier", 1.0)
# allow frontend to pass arch with a color like arch:tag
# but remove the tag
if self.arch is not None:

View File

@@ -551,6 +551,17 @@ export default function SimpleJob({
{ value: 'stepped', label: 'Stepped Recovery' },
]}
/>
{modelArch?.additionalSections?.includes('train.audio_loss_multiplier') && (
<NumberInput
label="Audio Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.audio_loss_multiplier ?? 1.0}
onChange={value => setJobConfig(value, 'config.process[0].train.audio_loss_multiplier')}
placeholder="eg. 1.0"
docKey={'train.audio_loss_multiplier'}
min={0}
/>
)}
</div>
<div>
<FormGroup label="EMA (Exponential Moving Average)">

View File

@@ -22,6 +22,7 @@ type AdditionalSections =
| 'datasets.audio_preserve_pitch'
| 'sample.ctrl_img'
| 'sample.multi_ctrl_imgs'
| 'train.audio_loss_multiplier'
| 'datasets.num_frames'
| 'model.multistage'
| 'model.layer_offloading'
@@ -642,13 +643,14 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.fps': [24, 1],
'config.process[0].sample.width': [768, 1024],
'config.process[0].sample.height': [768, 1024],
'config.process[0].train.audio_loss_multiplier': [1.0, undefined],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].datasets[x].do_i2v': [false, undefined],
'config.process[0].datasets[x].do_audio': [true, undefined],
'config.process[0].datasets[x].fps': [24, undefined],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v', 'train.audio_loss_multiplier'],
},
{
name: 'flux2_klein_4b',

View File

@@ -112,8 +112,8 @@ const docs: { [key: string]: ConfigDoc } = {
description: (
<>
For models that support audio with video, this option will load the audio from the video and resize it to match
the video sequence. Since the video is automatically resized, the audio may drop or raise in pitch to match the new
speed of the video. It is important to prep your dataset to have the proper length before training.
the video sequence. Since the video is automatically resized, the audio may drop or rise in pitch to match the
new speed of the video. It is important to prep your dataset to have the proper length before training.
</>
),
},
@@ -121,8 +121,9 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Audio Normalize',
description: (
<>
When loading audio, this will normalize the audio volume to the max peaks. Useful if your dataset has varying audio
volumes. Warning, do not use if you have clips with full silence you want to keep, as it will raise the volume of those clips.
When loading audio, this will normalize the audio volume to the max peaks. Useful if your dataset has varying
audio volumes. Warning, do not use if you have clips with full silence you want to keep, as it will raise the
volume of those clips.
</>
),
},
@@ -132,7 +133,7 @@ const docs: { [key: string]: ConfigDoc } = {
<>
When loading audio to match the number of frames requested, this option will preserve the pitch of the audio if
the length does not match training target. It is recommended to have a dataset that matches your target length,
as this option can add sound distortions.
as this option can add sound distortions.
</>
),
},
@@ -310,10 +311,21 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Num Repeats',
description: (
<>
Number of Repeats will allow you to repeate the items in a dataset multiple times. This is useful when you are using multiple
datasets and want to balance the number of samples from each dataset. For instance, if you have a small dataset of 10 images
and a large dataset of 100 images, you can set the small dataset to have 10 repeats to effectively make it 100 images, making
the two datasets occour equally during training.
Number of Repeats will allow you to repeat the items in a dataset multiple times. This is useful when you are
using multiple datasets and want to balance the number of samples from each dataset. For instance, if you have a
small dataset of 10 images and a large dataset of 100 images, you can set the small dataset to have 10 repeats
to effectively make it 100 images, making the two datasets occur equally during training.
</>
),
},
'train.audio_loss_multiplier': {
title: 'Audio Loss Multiplier',
description: (
<>
When training audio and video, sometimes the video loss is so great that it outweighs the audio loss, causing
the audio to become distorted. If you notice this happening, you can increase the audio loss multiplier to
give more weight to the audio loss. You could try something like 2.0, 10.0 etc. Warning, setting this too high
could overfit and damage the model.
</>
),
},

View File

@@ -146,6 +146,7 @@ export interface TrainConfig {
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
do_differential_guidance?: boolean;
differential_guidance_scale?: number;
audio_loss_multiplier?: number;
}
export interface QuantizeKwargsConfig {