add ability to use batch noise correction during training

This commit is contained in:
Jaret Burkett
2026-03-10 09:05:57 -06:00
parent b04c64e0f8
commit 06ef3d343a
2 changed files with 19 additions and 0 deletions

View File

@@ -1330,6 +1330,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch_noise = batch_noise * scn_scale
noise = noise + batch_noise
if self.train_config.do_batch_noise_correction:
if latents.shape[0] == 1:
# if we only have a batch size of 1, then we cant do batch noise correction, so we skip it
print_acc("Skipping batch noise correction because batch size is 1, increase batch size and num_repeats to use this feature")
else:
# shuffle tensors ensuring that no tensor is in the same position as before
batch_noise = latents.clone().roll(shifts=torch.randint(1, latents.shape[0], (1,)).item(), dims=0).to(noise.device, dtype=noise.dtype)
batch_noise_scale = torch.randn(
batch_noise.shape[0], batch_noise.shape[1], 1, 1,
device=batch_noise.device,
dtype=batch_noise.dtype
) * self.train_config.batch_noise_correction_scale
batch_noise = batch_noise * batch_noise_scale
noise = noise + batch_noise
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.randn(

View File

@@ -398,6 +398,10 @@ class TrainConfig:
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
self.do_signal_correction_noise = kwargs.get('do_signal_correction_noise', False)
# batch noise correction adds other images in the batch as noise to correct away from other images
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)