Added ability to use consistent noise for each image in a dataset by hashing the path and using that as a seed.

This commit is contained in:
Jaret Burkett
2025-02-08 07:13:48 -07:00
parent af5e760be1
commit c6d8eedb94
3 changed files with 31 additions and 3 deletions

View File

@@ -66,6 +66,7 @@ from toolkit.print import print_acc
from accelerate import Accelerator
import transformers
import diffusers
import hashlib
def flush():
torch.cuda.empty_cache()
@@ -862,11 +863,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_chunks.append(best_noise)
noise = torch.cat(noise_chunks, dim=0)
return noise
def get_consistent_noise(self, latents, batch: 'DataLoaderBatchDTO', dtype=torch.float32):
batch_num = latents.shape[0]
chunks = torch.chunk(latents, batch_num, dim=0)
noise_chunks = []
for idx, chunk in enumerate(chunks):
# get seed from path
file_item = batch.file_items[idx]
img_path = file_item.path
# add augmentors
if file_item.flip_x:
img_path += '_fx'
if file_item.flip_y:
img_path += '_fy'
seed = int(hashlib.md5(img_path.encode()).hexdigest(), 16) & 0xffffffff
generator = torch.Generator("cpu").manual_seed(seed)
noise_chunk = torch.randn(chunk.shape, generator=generator).to(chunk.device, dtype=dtype)
noise_chunks.append(noise_chunk)
noise = torch.cat(noise_chunks, dim=0).to(dtype=dtype)
return noise
def get_noise(self, latents, batch_size, dtype=torch.float32):
def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None):
if self.train_config.optimal_noise_pairing_samples > 1:
noise = self.get_optimal_noise(latents, dtype=dtype)
elif self.train_config.force_consistent_noise:
if batch is None:
raise ValueError("Batch must be provided for consistent noise")
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
@@ -1137,7 +1162,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
timesteps = torch.stack(timesteps, dim=0)
# get noise
noise = self.get_noise(latents, batch_size, dtype=dtype)
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
# this will negate any noise offsets

View File

@@ -407,6 +407,9 @@ class TrainConfig:
# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
# forces same noise for the same image at a given size.
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
class ModelConfig: