mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 22:19:48 +00:00
add ability to use batch noise correction during training
This commit is contained in:
@@ -1330,6 +1330,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch_noise = batch_noise * scn_scale
|
||||
noise = noise + batch_noise
|
||||
|
||||
if self.train_config.do_batch_noise_correction:
|
||||
if latents.shape[0] == 1:
|
||||
# if we only have a batch size of 1, then we cant do batch noise correction, so we skip it
|
||||
print_acc("Skipping batch noise correction because batch size is 1, increase batch size and num_repeats to use this feature")
|
||||
else:
|
||||
# shuffle tensors ensuring that no tensor is in the same position as before
|
||||
batch_noise = latents.clone().roll(shifts=torch.randint(1, latents.shape[0], (1,)).item(), dims=0).to(noise.device, dtype=noise.dtype)
|
||||
batch_noise_scale = torch.randn(
|
||||
batch_noise.shape[0], batch_noise.shape[1], 1, 1,
|
||||
device=batch_noise.device,
|
||||
dtype=batch_noise.dtype
|
||||
) * self.train_config.batch_noise_correction_scale
|
||||
batch_noise = batch_noise * batch_noise_scale
|
||||
noise = noise + batch_noise
|
||||
|
||||
if self.train_config.random_noise_shift > 0.0:
|
||||
# get random noise -1 to 1
|
||||
noise_shift = torch.randn(
|
||||
|
||||
@@ -398,6 +398,10 @@ class TrainConfig:
|
||||
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
|
||||
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
|
||||
self.do_signal_correction_noise = kwargs.get('do_signal_correction_noise', False)
|
||||
# batch noise correction adds other images in the batch as noise to correct away from other images
|
||||
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
|
||||
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
|
||||
|
||||
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
|
||||
Reference in New Issue
Block a user