diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e55525aa..01418c74 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1330,6 +1330,21 @@ class BaseSDTrainProcess(BaseTrainProcess): batch_noise = batch_noise * scn_scale noise = noise + batch_noise + if self.train_config.do_batch_noise_correction: + if latents.shape[0] == 1: + # if we only have a batch size of 1, then we cant do batch noise correction, so we skip it + print_acc("Skipping batch noise correction because batch size is 1, increase batch size and num_repeats to use this feature") + else: + # shuffle tensors ensuring that no tensor is in the same position as before + batch_noise = latents.clone().roll(shifts=torch.randint(1, latents.shape[0], (1,)).item(), dims=0).to(noise.device, dtype=noise.dtype) + batch_noise_scale = torch.randn( + batch_noise.shape[0], batch_noise.shape[1], 1, 1, + device=batch_noise.device, + dtype=batch_noise.dtype + ) * self.train_config.batch_noise_correction_scale + batch_noise = batch_noise * batch_noise_scale + noise = noise + batch_noise + if self.train_config.random_noise_shift > 0.0: # get random noise -1 to 1 noise_shift = torch.randn( diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 51355fee..03c04cf1 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -398,6 +398,10 @@ class TrainConfig: self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0) self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0) self.do_signal_correction_noise = kwargs.get('do_signal_correction_noise', False) + # batch noise correction adds other images in the batch as noise to correct away from other images + self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False) + self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1) + self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0) self.random_noise_shift = kwargs.get('random_noise_shift', 0.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0)