mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added ability to use consistent noise for each image in a dataset by hashing the path and using that as a seed.
This commit is contained in:
@@ -66,6 +66,7 @@ from toolkit.print import print_acc
|
||||
from accelerate import Accelerator
|
||||
import transformers
|
||||
import diffusers
|
||||
import hashlib
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -862,11 +863,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise_chunks.append(best_noise)
|
||||
noise = torch.cat(noise_chunks, dim=0)
|
||||
return noise
|
||||
|
||||
def get_consistent_noise(self, latents, batch: 'DataLoaderBatchDTO', dtype=torch.float32):
|
||||
batch_num = latents.shape[0]
|
||||
chunks = torch.chunk(latents, batch_num, dim=0)
|
||||
noise_chunks = []
|
||||
for idx, chunk in enumerate(chunks):
|
||||
# get seed from path
|
||||
file_item = batch.file_items[idx]
|
||||
img_path = file_item.path
|
||||
# add augmentors
|
||||
if file_item.flip_x:
|
||||
img_path += '_fx'
|
||||
if file_item.flip_y:
|
||||
img_path += '_fy'
|
||||
seed = int(hashlib.md5(img_path.encode()).hexdigest(), 16) & 0xffffffff
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
noise_chunk = torch.randn(chunk.shape, generator=generator).to(chunk.device, dtype=dtype)
|
||||
noise_chunks.append(noise_chunk)
|
||||
noise = torch.cat(noise_chunks, dim=0).to(dtype=dtype)
|
||||
return noise
|
||||
|
||||
|
||||
def get_noise(self, latents, batch_size, dtype=torch.float32):
|
||||
def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None):
|
||||
if self.train_config.optimal_noise_pairing_samples > 1:
|
||||
noise = self.get_optimal_noise(latents, dtype=dtype)
|
||||
elif self.train_config.force_consistent_noise:
|
||||
if batch is None:
|
||||
raise ValueError("Batch must be provided for consistent noise")
|
||||
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
|
||||
else:
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
@@ -1137,7 +1162,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
timesteps = torch.stack(timesteps, dim=0)
|
||||
|
||||
# get noise
|
||||
noise = self.get_noise(latents, batch_size, dtype=dtype)
|
||||
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)
|
||||
|
||||
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
|
||||
# this will negate any noise offsets
|
||||
|
||||
Submodule repositories/sd-scripts updated: 25f961bc77...b78c0e2a69
@@ -407,6 +407,9 @@ class TrainConfig:
|
||||
|
||||
# optimal noise pairing
|
||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||
|
||||
# forces same noise for the same image at a given size.
|
||||
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
Reference in New Issue
Block a user