From c6d8eedb94e557105a240377ffa24ec4c9d9ab67 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Feb 2025 07:13:48 -0700 Subject: [PATCH] Added ability to use consistent noise for each image in a dataset by hashing the path and using that as a seed. --- jobs/process/BaseSDTrainProcess.py | 29 +++++++++++++++++++++++++++-- repositories/sd-scripts | 2 +- toolkit/config_modules.py | 3 +++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 2183709f..6f057455 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -66,6 +66,7 @@ from toolkit.print import print_acc from accelerate import Accelerator import transformers import diffusers +import hashlib def flush(): torch.cuda.empty_cache() @@ -862,11 +863,35 @@ class BaseSDTrainProcess(BaseTrainProcess): noise_chunks.append(best_noise) noise = torch.cat(noise_chunks, dim=0) return noise + + def get_consistent_noise(self, latents, batch: 'DataLoaderBatchDTO', dtype=torch.float32): + batch_num = latents.shape[0] + chunks = torch.chunk(latents, batch_num, dim=0) + noise_chunks = [] + for idx, chunk in enumerate(chunks): + # get seed from path + file_item = batch.file_items[idx] + img_path = file_item.path + # add augmentors + if file_item.flip_x: + img_path += '_fx' + if file_item.flip_y: + img_path += '_fy' + seed = int(hashlib.md5(img_path.encode()).hexdigest(), 16) & 0xffffffff + generator = torch.Generator("cpu").manual_seed(seed) + noise_chunk = torch.randn(chunk.shape, generator=generator).to(chunk.device, dtype=dtype) + noise_chunks.append(noise_chunk) + noise = torch.cat(noise_chunks, dim=0).to(dtype=dtype) + return noise - def get_noise(self, latents, batch_size, dtype=torch.float32): + def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None): if self.train_config.optimal_noise_pairing_samples > 1: noise = self.get_optimal_noise(latents, dtype=dtype) + elif self.train_config.force_consistent_noise: + if batch is None: + raise ValueError("Batch must be provided for consistent noise") + noise = self.get_consistent_noise(latents, batch, dtype=dtype) else: # get noise noise = self.sd.get_latent_noise( @@ -1137,7 +1162,7 @@ class BaseSDTrainProcess(BaseTrainProcess): timesteps = torch.stack(timesteps, dim=0) # get noise - noise = self.get_noise(latents, batch_size, dtype=dtype) + noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch) # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents # this will negate any noise offsets diff --git a/repositories/sd-scripts b/repositories/sd-scripts index 25f961bc..b78c0e2a 160000 --- a/repositories/sd-scripts +++ b/repositories/sd-scripts @@ -1 +1 @@ -Subproject commit 25f961bc779bc79aef440813e3e8e92244ac5739 +Subproject commit b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 7aa30b73..30762c72 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -407,6 +407,9 @@ class TrainConfig: # optimal noise pairing self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1) + + # forces same noise for the same image at a given size. + self.force_consistent_noise = kwargs.get('force_consistent_noise', False) class ModelConfig: