From dcd98dc0d548a59654d65ed9397da924f1350227 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 21 Mar 2026 07:44:18 -0600 Subject: [PATCH] Add signal amplification --- extensions_built_in/sd_trainer/SDTrainer.py | 11 ++++++++++- toolkit/config_modules.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 63c3874b..2ae346b6 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -573,7 +573,16 @@ class SDTrainer(BaseSDTrainProcess): elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) - + elif self.train_config.do_signal_amplification: + if not self.sd.is_flow_matching: + raise ValueError("Signal amplification is only supported for flow matching models") + with torch.no_grad(): + nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype) + while len(nas.shape) < len(noise.shape): + nas = nas.unsqueeze(-1) + aug = batch.latents * nas + target = noise - (batch.latents + aug) + target = target.detach() elif hasattr(self.sd, 'get_loss_target'): target = self.sd.get_loss_target( noise=noise, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 418d635d..428574f4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -401,6 +401,7 @@ class TrainConfig: # batch noise correction adds other images in the batch as noise to correct away from other images self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False) self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1) + self.do_signal_amplification = kwargs.get('do_signal_amplification', False) self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0) self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)