Add signal correction noise

This commit is contained in:
Jaret Burkett
2026-02-07 09:49:55 -07:00
parent e82cf6eec2
commit 454722cc97
2 changed files with 11 additions and 0 deletions

View File

@@ -1306,6 +1306,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = noise * noise_multiplier
if self.train_config.do_signal_correction_noise:
batch_noise = latents.clone().to(noise.device, dtype=noise.dtype)
scn_scale = torch.randn(
batch_noise.shape[0], batch_noise.shape[1], 1, 1,
device=batch_noise.device,
dtype=batch_noise.dtype
)
batch_noise = batch_noise * scn_scale
noise = noise + batch_noise
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.randn(

View File

@@ -397,6 +397,7 @@ class TrainConfig:
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
self.do_signal_correction_noise = kwargs.get('do_signal_correction_noise', False)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)