From aba35cde5ff0deea827863225728e6ab5eed2737 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Sat, 19 Oct 2024 12:28:14 -0400
Subject: [PATCH] Fix Zero Terminal SNR option

---
Reviewer notes (text between "---" and the first "diff --git" is not part
of the commit):

* backend/modules/k_prediction.py
  - Adds rescale_zero_terminal_snr_sigmas(sigmas). It converts sigmas to
    sqrt(alphas_cumprod) via 1 / (sigma^2 + 1), subtracts the last entry so
    the curve ends at zero, rescales so the first entry keeps its original
    value, squares the result, overwrites the last squared entry with
    4.8973451890853435e-08 (presumably so the final division does not
    divide by zero), and converts back with sqrt((1 - a) / a).
  - Adds Prediction.set_sigmas(sigmas), which re-registers the 'sigmas'
    and 'log_sigmas' buffers from the given tensor.
* modules/processing.py
  - Imports rescale_zero_terminal_snr_sigmas.
  - In process_images_inner, replaces the 'alphas_cumprod_modifiers'
    handling with: when opts.sd_noise_schedule == "Zero Terminal SNR",
    record it in p.extra_generation_params['Noise Schedule'], keep a
    reference to the predictor's current sigmas, install the rescaled
    sigmas before p.sample(...), and put the saved sigmas back afterwards.

NOTE(review): the copy of this patch under review had its line structure
collapsed. Line breaks, blank context lines and Python indentation below
were reconstructed so that every hunk matches its "@@" line counts and the
diffstat; the indentation depth inside process_images_inner is inferred.
Verify against the upstream commit before applying.
NOTE(review): in the visible hunk the sigmas are restored after p.sample()
without try/finally, so an exception during sampling would leave the
rescaled sigmas registered. Confirm whether that is acceptable.
NOTE(review): "and p is not None" is evaluated after p.n_iter and
p.sd_model have already been dereferenced in this hunk, so the check looks
redundant.

 backend/modules/k_prediction.py | 24 ++++++++++++++++++++++++
 modules/processing.py           | 20 +++++++++-----------
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py
index 7264d7d1..18fb2aac 100644
--- a/backend/modules/k_prediction.py
+++ b/backend/modules/k_prediction.py
@@ -45,6 +45,26 @@ def flux_time_shift(mu, sigma, t):
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
 
+def rescale_zero_terminal_snr_sigmas(sigmas):
+    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas_bar[-1] = 4.8973451890853435e-08
+    return ((1 - alphas_bar) / alphas_bar) ** 0.5
+
+
 class AbstractPrediction(torch.nn.Module):
     def __init__(self, sigma_data=1.0, prediction_type='epsilon'):
         super().__init__()
@@ -114,6 +134,10 @@ class Prediction(AbstractPrediction):
         self.register_buffer('log_sigmas', sigmas.log().float())
         return
 
+    def set_sigmas(self, sigmas):
+        self.register_buffer('sigmas', sigmas.float())
+        self.register_buffer('log_sigmas', sigmas.log().float())
+
     @property
     def sigma_min(self):
         return self.sigmas[0]
diff --git a/modules/processing.py b/modules/processing.py
index ca907412..9bc3333d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -35,6 +35,7 @@ from modules.sd_models import apply_token_merging, forge_model_reload
 from modules_forge.utils import apply_circular_forge
 from modules_forge import main_entry
 from backend import memory_management
+from backend.modules.k_prediction import rescale_zero_terminal_snr_sigmas
 
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -969,25 +970,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
+            # TODO: This currently seems broken. It should be fixed or removed.
             sd_models.apply_alpha_schedule_override(p.sd_model, p)
 
-            alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', [])
-            alphas_cumprod_backup = None
-
-            if len(alphas_cumprod_modifiers) > 0:
-                alphas_cumprod_backup = p.sd_model.alphas_cumprod
-                for modifier in alphas_cumprod_modifiers:
-                    p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod)
-                p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
+            sigmas_backup = None
+            if opts.sd_noise_schedule == "Zero Terminal SNR" and p is not None:
+                p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+                sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas
+                p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))
 
             samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             for x_sample in samples_ddim:
                 p.latents_after_sampling.append(x_sample)
 
-            if alphas_cumprod_backup is not None:
-                p.sd_model.alphas_cumprod = alphas_cumprod_backup
-                p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
+            if sigmas_backup is not None:
+                p.sd_model.forge_objects.unet.model.predictor.set_sigmas(sigmas_backup)
 
             if p.scripts is not None:
                 ps = scripts.PostSampleArgs(samples_ddim)