mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Added noise offset
This commit is contained in:
@@ -17,7 +17,7 @@ from diffusers import StableDiffusionPipeline
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
|
||||
import torch
|
||||
@@ -38,6 +38,9 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4; the same holds for XL.
|
||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
|
||||
self.vae = vae
|
||||
@@ -94,6 +97,7 @@ class TrainConfig:
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
self.train_unet = kwargs.get('train_unet', True)
|
||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@@ -506,13 +510,19 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
|
||||
latents = train_util.get_initial_latents(
|
||||
noise_scheduler,
|
||||
self.train_config.batch_size,
|
||||
height,
|
||||
width,
|
||||
1
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
# get noise
|
||||
noise = torch.randn(
|
||||
(
|
||||
self.train_config.batch_size,
|
||||
UNET_IN_CHANNELS,
|
||||
height // VAE_SCALE_FACTOR,
|
||||
width // VAE_SCALE_FACTOR,
|
||||
),
|
||||
device="cpu",
|
||||
)
|
||||
noise = apply_noise_offset(noise, self.train_config.noise_offset)
|
||||
latents = noise * noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
@@ -673,7 +683,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
# end of step
|
||||
self.step_num = step
|
||||
|
||||
self.sample(self.step_num)
|
||||
self.sample(self.step_num + 1)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
|
||||
@@ -369,3 +369,11 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(noise, noise_offset):
    """Add a random per-sample, per-channel offset to a noise tensor.

    One Gaussian value is drawn for every (dim 0, dim 1) pair of ``noise``,
    scaled by ``noise_offset`` and broadcast over all remaining dimensions,
    following the "offset noise" technique referenced above this function.

    Args:
        noise: Tensor with at least two dimensions; the offset varies along
            the first two and is constant along the rest.
        noise_offset: Scale of the offset. ``None`` or any value below 1e-7
            (which includes negative values) disables the offset.

    Returns:
        ``noise`` itself (same object) when the offset is disabled; otherwise
        a new tensor with the same shape, dtype and device as ``noise``.
    """
    if noise_offset is None or noise_offset < 0.0000001:
        return noise

    # One value per (dim 0, dim 1) pair; trailing singleton dims are sized
    # from the input rank so inputs other than 4-D broadcast correctly.
    offset_shape = (noise.shape[0], noise.shape[1]) + (1,) * (noise.dim() - 2)
    offset = noise_offset * torch.randn(offset_shape, device=noise.device)

    # Cast the offset so a reduced-precision input is not silently promoted
    # to float32 by the addition.
    return noise + offset.to(noise.dtype)
|
||||
|
||||
Reference in New Issue
Block a user