Fix Zero Terminal SNR option

This commit is contained in:
catboxanon
2024-10-19 12:28:14 -04:00
parent 534405e597
commit aba35cde5f
2 changed files with 33 additions and 11 deletions

View File

@@ -35,6 +35,7 @@ from modules.sd_models import apply_token_merging, forge_model_reload
from modules_forge.utils import apply_circular_forge
from modules_forge import main_entry
from backend import memory_management
from backend.modules.k_prediction import rescale_zero_terminal_snr_sigmas
# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -969,25 +970,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
# TODO: This currently seems broken. It should be fixed or removed.
sd_models.apply_alpha_schedule_override(p.sd_model, p)
alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', [])
alphas_cumprod_backup = None
if len(alphas_cumprod_modifiers) > 0:
alphas_cumprod_backup = p.sd_model.alphas_cumprod
for modifier in alphas_cumprod_modifiers:
p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod)
p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
sigmas_backup = None
if opts.sd_noise_schedule == "Zero Terminal SNR" and p is not None:
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
for x_sample in samples_ddim:
p.latents_after_sampling.append(x_sample)
if alphas_cumprod_backup is not None:
p.sd_model.alphas_cumprod = alphas_cumprod_backup
p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
if sigmas_backup is not None:
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(sigmas_backup)
if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)