mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-23 05:43:59 +00:00
Improved the method to augment random noise
This commit is contained in:
@@ -1304,28 +1304,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
|
||||
s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
|
||||
|
||||
if self.train_config.random_noise_multiplier > 0.0:
|
||||
|
||||
# do it on a per batch item, per channel basis
|
||||
noise_multiplier = 1 + torch.randn(
|
||||
s,
|
||||
device=noise.device,
|
||||
dtype=noise.dtype
|
||||
) * self.train_config.random_noise_multiplier
|
||||
|
||||
with self.timer('make_noisy_latents'):
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
|
||||
if self.train_config.random_noise_shift > 0.0:
|
||||
# get random noise -1 to 1
|
||||
noise_shift = torch.randn(
|
||||
s,
|
||||
batch_size, latents.shape[1], 1, 1,
|
||||
device=noise.device,
|
||||
dtype=noise.dtype
|
||||
) * self.train_config.random_noise_shift
|
||||
# add to noise
|
||||
noise += noise_shift
|
||||
|
||||
if self.train_config.random_noise_multiplier > 0.0:
|
||||
sigma = self.train_config.random_noise_multiplier
|
||||
noise_multiplier = torch.exp(torch.randn(s, device=noise.device, dtype=noise.dtype) * sigma)
|
||||
|
||||
with self.timer('make_noisy_latents'):
|
||||
|
||||
latent_multiplier = self.train_config.latent_multiplier
|
||||
|
||||
|
||||
Reference in New Issue
Block a user