Add signal correction noise

This commit is contained in:
Jaret Burkett
2026-02-07 09:49:55 -07:00
parent e82cf6eec2
commit 454722cc97
2 changed files with 11 additions and 0 deletions

View File

@@ -1306,6 +1306,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = noise * noise_multiplier
if self.train_config.do_signal_correction_noise:
batch_noise = latents.clone().to(noise.device, dtype=noise.dtype)
scn_scale = torch.randn(
batch_noise.shape[0], batch_noise.shape[1], 1, 1,
device=batch_noise.device,
dtype=batch_noise.dtype
)
batch_noise = batch_noise * scn_scale
noise = noise + batch_noise
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.randn(