diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d0663e91..40e07514 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1304,28 +1304,23 @@ class BaseSDTrainProcess(BaseTrainProcess): # if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1) - if self.train_config.random_noise_multiplier > 0.0: - - # do it on a per batch item, per channel basis - noise_multiplier = 1 + torch.randn( - s, - device=noise.device, - dtype=noise.dtype - ) * self.train_config.random_noise_multiplier - - with self.timer('make_noisy_latents'): - noise = noise * noise_multiplier if self.train_config.random_noise_shift > 0.0: # get random noise -1 to 1 noise_shift = torch.randn( - s, + batch_size, latents.shape[1], 1, 1, device=noise.device, dtype=noise.dtype ) * self.train_config.random_noise_shift # add to noise noise += noise_shift + + if self.train_config.random_noise_multiplier > 0.0: + sigma = self.train_config.random_noise_multiplier + noise_multiplier = torch.exp(torch.randn(s, device=noise.device, dtype=noise.dtype) * sigma) + + with self.timer('make_noisy_latents'): latent_multiplier = self.train_config.latent_multiplier