Adjust signal amplification target. Allow signal amplification strength in config.

This commit is contained in:
Jaret Burkett
2026-03-22 08:30:13 -06:00
parent dcd98dc0d5
commit 0f075fc45e
2 changed files with 3 additions and 0 deletions

View File

@@ -578,6 +578,8 @@ class SDTrainer(BaseSDTrainProcess):
raise ValueError("Signal amplification is only supported for flow matching models")
with torch.no_grad():
nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype)
nas = nas * self.train_config.signal_amplification_strength
nas = nas.clamp(min=0.1)
while len(nas.shape) < len(noise.shape):
nas = nas.unsqueeze(-1)
aug = batch.latents * nas

View File

@@ -402,6 +402,7 @@ class TrainConfig:
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
self.do_signal_amplification = kwargs.get('do_signal_amplification', False)
self.signal_amplification_strength = kwargs.get('signal_amplification_strength', 0.5)
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)