Add signal amplification

This commit is contained in:
Jaret Burkett
2026-03-21 07:44:18 -06:00
parent 35b1cde3cb
commit dcd98dc0d5
2 changed files with 11 additions and 1 deletions

View File

@@ -573,7 +573,16 @@ class SDTrainer(BaseSDTrainProcess):
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif self.train_config.do_signal_amplification:
if not self.sd.is_flow_matching:
raise ValueError("Signal amplification is only supported for flow matching models")
with torch.no_grad():
nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype)
while len(nas.shape) < len(noise.shape):
nas = nas.unsqueeze(-1)
aug = batch.latents * nas
target = noise - (batch.latents + aug)
target = target.detach()
elif hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,

View File

@@ -401,6 +401,7 @@ class TrainConfig:
# batch noise correction adds other images in the batch as noise to correct away from other images
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
self.do_signal_amplification = kwargs.get('do_signal_amplification', False)
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)