Improved the method to augment random noise

This commit is contained in:
Jaret Burkett
2026-02-06 15:44:10 -07:00
parent 115f0a3670
commit 1422789452

View File

@@ -1304,28 +1304,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
if self.train_config.random_noise_multiplier > 0.0:
# do it on a per batch item, per channel basis
noise_multiplier = 1 + torch.randn(
s,
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_multiplier
with self.timer('make_noisy_latents'):
noise = noise * noise_multiplier
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.randn(
s,
batch_size, latents.shape[1], 1, 1,
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_shift
# add to noise
noise += noise_shift
if self.train_config.random_noise_multiplier > 0.0:
sigma = self.train_config.random_noise_multiplier
noise_multiplier = torch.exp(torch.randn(s, device=noise.device, dtype=noise.dtype) * sigma)
with self.timer('make_noisy_latents'):
latent_multiplier = self.train_config.latent_multiplier