From 434fb22458a24f98b0b3ae026c9e049e8ef4ec0a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 22 Jul 2023 15:01:01 -0600 Subject: [PATCH] Adde dnoise offset --- jobs/process/TrainSliderProcess.py | 28 +++++++++++++++++++--------- toolkit/train_tools.py | 8 ++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index bc619556..c7f04cc1 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -17,7 +17,7 @@ from diffusers import StableDiffusionPipeline from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors -from toolkit.train_tools import get_torch_dtype +from toolkit.train_tools import get_torch_dtype, apply_noise_offset import gc import torch @@ -38,6 +38,9 @@ def flush(): gc.collect() +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + class StableDiffusion: def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler): self.vae = vae @@ -94,6 +97,7 @@ class TrainConfig: self.xformers = kwargs.get('xformers', False) self.train_unet = kwargs.get('train_unet', True) self.train_text_encoder = kwargs.get('train_text_encoder', True) + self.noise_offset = kwargs.get('noise_offset', 0.0) class ModelConfig: @@ -506,13 +510,19 @@ class TrainSliderProcess(BaseTrainProcess): 1, self.train_config.max_denoising_steps, (1,) ).item() - latents = train_util.get_initial_latents( - noise_scheduler, - self.train_config.batch_size, - height, - width, - 1 - ).to(self.device_torch, dtype=dtype) + # get noise + noise = torch.randn( + ( + self.train_config.batch_size, + UNET_IN_CHANNELS, + height // VAE_SCALE_FACTOR, + width // VAE_SCALE_FACTOR, + ), + device="cpu", + ) + noise = apply_noise_offset(noise, self.train_config.noise_offset) + latents = noise * noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) with self.network: assert self.network.is_active @@ -673,7 +683,7 @@ class TrainSliderProcess(BaseTrainProcess): # end of step self.step_num = step - self.sample(self.step_num) + self.sample(self.step_num + 1) print("") self.save() diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 4d6e38e4..2082811c 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -369,3 +369,11 @@ def sample_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(noise, noise_offset): + if noise_offset is None or noise_offset < 0.0000001: + return noise + noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device) + return noise