Added noise offset

This commit is contained in:
Jaret Burkett
2023-07-22 15:01:01 -06:00
parent 3f4f429c4a
commit 434fb22458
2 changed files with 27 additions and 9 deletions

View File

@@ -17,7 +17,7 @@ from diffusers import StableDiffusionPipeline
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
@@ -38,6 +38,9 @@ def flush():
gc.collect()
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4. The same holds for XL.
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class StableDiffusion:
def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
self.vae = vae
@@ -94,6 +97,7 @@ class TrainConfig:
self.xformers = kwargs.get('xformers', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.noise_offset = kwargs.get('noise_offset', 0.0)
class ModelConfig:
@@ -506,13 +510,19 @@ class TrainSliderProcess(BaseTrainProcess):
1, self.train_config.max_denoising_steps, (1,)
).item()
latents = train_util.get_initial_latents(
noise_scheduler,
self.train_config.batch_size,
height,
width,
1
).to(self.device_torch, dtype=dtype)
# get noise
noise = torch.randn(
(
self.train_config.batch_size,
UNET_IN_CHANNELS,
height // VAE_SCALE_FACTOR,
width // VAE_SCALE_FACTOR,
),
device="cpu",
)
noise = apply_noise_offset(noise, self.train_config.noise_offset)
latents = noise * noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
@@ -673,7 +683,7 @@ class TrainSliderProcess(BaseTrainProcess):
# end of step
self.step_num = step
self.sample(self.step_num)
self.sample(self.step_num + 1)
print("")
self.save()

View File

@@ -369,3 +369,11 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
    """Add "offset noise" to a noise tensor.

    One standard-normal value is drawn per (sample, channel) pair, scaled by
    ``noise_offset`` and broadcast-added over the trailing two dimensions, so
    every element of a given channel is shifted by the same amount.

    Args:
        noise: Tensor indexed as ``(batch, channels, ...)``; the offset is
            built with shape ``(batch, channels, 1, 1)``, so a 4-D tensor is
            expected for the broadcast to line up.
        noise_offset: Strength of the offset. ``None`` or any value below
            ``1e-7`` (including negatives) disables the feature.

    Returns:
        ``noise`` itself (the same object) when the offset is disabled;
        otherwise a new tensor with the same shape, device and dtype as
        ``noise``.
    """
    # None, zero and negative strengths all mean "feature off".
    if noise_offset is None or noise_offset < 0.0000001:
        return noise

    batch_size, channels = noise.shape[0], noise.shape[1]
    # Sample in the default dtype (same RNG draw as before), then cast to the
    # noise dtype so half-precision noise is not promoted to float32 by the
    # addition below.
    channel_shift = torch.randn(
        (batch_size, channels, 1, 1), device=noise.device
    ).to(noise.dtype)
    return noise + noise_offset * channel_shift