diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 7815dc5a..3b210154 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -778,15 +778,37 @@ class BaseSDTrainProcess(BaseTrainProcess): while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma + + def get_optimal_noise(self, latents, dtype=torch.float32): + batch_num = latents.shape[0] + chunks = torch.chunk(latents, batch_num, dim=0) + noise_chunks = [] + for chunk in chunks: + noise_samples = [torch.randn_like(chunk, device=chunk.device, dtype=dtype) for _ in range(self.train_config.optimal_noise_pairing_samples)] + # find the one most similar to the chunk + lowest_loss = 999999999999 + best_noise = None + for noise in noise_samples: + loss = torch.nn.functional.mse_loss(chunk, noise) + if loss < lowest_loss: + lowest_loss = loss + best_noise = noise + noise_chunks.append(best_noise) + noise = torch.cat(noise_chunks, dim=0) + return noise + def get_noise(self, latents, batch_size, dtype=torch.float32): - # get noise - noise = self.sd.get_latent_noise( - height=latents.shape[2], - width=latents.shape[3], - batch_size=batch_size, - noise_offset=self.train_config.noise_offset, - ).to(self.device_torch, dtype=dtype) + if self.train_config.optimal_noise_pairing_samples > 1: + noise = self.get_optimal_noise(latents, dtype=dtype) + else: + # get noise + noise = self.sd.get_latent_noise( + height=latents.shape[2], + width=latents.shape[3], + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) if self.train_config.random_noise_shift > 0.0: # get random noise -1 to 1 diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 592f6e09..b4417afc 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -403,6 +403,9 @@ class TrainConfig: # diffusion feature extractor self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None) + + # optimal noise pairing + self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1) class ModelConfig: