mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 10:09:49 +00:00
Add signal correction noise
This commit is contained in:
@@ -1306,6 +1306,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
|
||||
if self.train_config.do_signal_correction_noise:
|
||||
batch_noise = latents.clone().to(noise.device, dtype=noise.dtype)
|
||||
scn_scale = torch.randn(
|
||||
batch_noise.shape[0], batch_noise.shape[1], 1, 1,
|
||||
device=batch_noise.device,
|
||||
dtype=batch_noise.dtype
|
||||
)
|
||||
batch_noise = batch_noise * scn_scale
|
||||
noise = noise + batch_noise
|
||||
|
||||
if self.train_config.random_noise_shift > 0.0:
|
||||
# get random noise -1 to 1
|
||||
noise_shift = torch.randn(
|
||||
|
||||
@@ -397,6 +397,7 @@ class TrainConfig:
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
|
||||
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
|
||||
self.do_signal_correction_noise = kwargs.get('do_signal_correction_noise', False)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
||||
|
||||
Reference in New Issue
Block a user